[flang-commits] [flang] [Flang][OpenMP] Correct ArrayElements in Reduction Clause (PR #196094)
Jack Styles via flang-commits
flang-commits at lists.llvm.org
Wed May 6 08:21:23 PDT 2026
https://github.com/Stylie777 created https://github.com/llvm/llvm-project/pull/196094
Currently, when an array element is used as a list item in a reduction clause, it is lowered with the reduction referencing the box that holds the whole array, rather than just the element.
To address this, adjust Flang lowering to track expressions alongside symbols, so that the array-element context is not lost when lowering a reduction on an array element. As a result, the reduction variable in HLFIR has just the element's type rather than the type of the full array.
DO CONCURRENT is not covered by this change, because its reduction list does not allow array elements. Support is currently limited to array elements, but it could be extended to array sections in the future.
Assisted-by: Codex
>From dab10b236e63e977c0e4dc14985d9da89b798f13 Mon Sep 17 00:00:00 2001
From: Jack Styles <jack.styles at arm.com>
Date: Thu, 26 Feb 2026 09:58:06 +0000
Subject: [PATCH] [Flang][OpenMP] Reduce ArrayElements in Reduction Clause
Currently, when an ArrayElement is used within a Reduction clause,
it will be lowered with the reduction referencing the box containing
the array, not just the element.
To address this, adjust Flang lowering to track expressions alongside
symbol to ensure that the Array Element context is not lost and
considered when lowering a reduction with Array Element. This ensures
that, when represented in HLFIR, it will be just the element's type,
rather than the full array.
Currently this excludes DO CONCURRENT as it excludes Array Elements,
and is limited to Array Elements but there are options to expand this
into Array Sections in the future.
Assisted-by: Codex
---
.../flang/Lower/Support/ReductionProcessor.h | 19 +-
flang/include/flang/Support/OpenMP-utils.h | 5 +-
flang/lib/Lower/Bridge.cpp | 4 +-
flang/lib/Lower/ConvertExprToHLFIR.cpp | 31 +-
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 35 +-
flang/lib/Lower/OpenMP/ClauseProcessor.h | 15 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 325 ++++++++++++++----
.../lib/Lower/Support/ReductionProcessor.cpp | 163 +++++----
.../Lower/OpenMP/reduction-array-element.f90 | 114 ++++++
9 files changed, 558 insertions(+), 153 deletions(-)
create mode 100644 flang/test/Lower/OpenMP/reduction-array-element.f90
diff --git a/flang/include/flang/Lower/Support/ReductionProcessor.h b/flang/include/flang/Lower/Support/ReductionProcessor.h
index 0b4a692827a79..a1dab8fbc4d5e 100644
--- a/flang/include/flang/Lower/Support/ReductionProcessor.h
+++ b/flang/include/flang/Lower/Support/ReductionProcessor.h
@@ -38,6 +38,11 @@ namespace Fortran {
namespace lower {
namespace omp {
+struct ReductionValueCache {
+ llvm::DenseMap<const semantics::Symbol *, mlir::Value> symbolCache;
+ lower::ExprToValueMap exprCache;
+};
+
class ReductionProcessor {
public:
// ompOrig: mold/original variable
@@ -145,11 +150,11 @@ class ReductionProcessor {
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
/// \param [in,out] reductionVarCache - optional cache mapping reduction
- /// symbols to their SSA values. When provided, array/box reduction
- /// variables that have already been allocated will be reused instead of
- /// creating new allocas. This ensures that nested composite wrappers
- /// (e.g. wsloop and simd in DO SIMD) share the same SSA values, allowing
- /// the genLoopVars() mapper to correctly remap inner wrapper operands.
+ /// objects to their SSA values. Scalar array elements are keyed by
+ /// expression, while whole-symbol reductions are keyed by symbol. This
+ /// ensures that nested composite wrappers (e.g. wsloop and simd in DO SIMD)
+ /// share the same SSA values without conflating distinct element
+ /// expressions of the same base symbol.
template <typename OpType, typename RedOperatorListTy>
static bool processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
@@ -158,8 +163,8 @@ class ReductionProcessor {
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value>
- *reductionVarCache = nullptr);
+ const llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symMap, ReductionValueCache *reductionVarCache = nullptr);
};
template <typename FloatOp, typename IntegerOp>
diff --git a/flang/include/flang/Support/OpenMP-utils.h b/flang/include/flang/Support/OpenMP-utils.h
index 6d9db2b682c50..435215b887de9 100644
--- a/flang/include/flang/Support/OpenMP-utils.h
+++ b/flang/include/flang/Support/OpenMP-utils.h
@@ -9,6 +9,7 @@
#ifndef FORTRAN_SUPPORT_OPENMP_UTILS_H_
#define FORTRAN_SUPPORT_OPENMP_UTILS_H_
+#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Semantics/symbol.h"
#include "mlir/IR/Builders.h"
@@ -22,12 +23,14 @@ namespace Fortran::common::openmp {
struct EntryBlockArgsEntry {
llvm::ArrayRef<const Fortran::semantics::Symbol *> syms;
llvm::ArrayRef<mlir::Value> vars;
+ llvm::ArrayRef<const Fortran::lower::omp::Object *> objs;
bool isValid() const {
// This check allows specifying a smaller number of symbols than values
// because in some case cases a single symbol generates multiple block
// arguments.
- return syms.size() <= vars.size();
+ return syms.size() <= vars.size() &&
+ (objs.empty() || objs.size() == syms.size());
}
};
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index c5709e1cd94d4..42e14d6a4dea3 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -128,6 +128,7 @@ struct IncrementLoopInfo {
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> reduceSymList;
+ llvm::SmallVector<const Fortran::lower::omp::Object *> reduceObjList;
llvm::SmallVector<fir::ReduceOperationEnum> reduceOperatorList;
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
@@ -2413,7 +2414,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::lower::omp::ReductionProcessor rp;
bool result = rp.processReductionArguments<fir::DeclareReductionOp>(
toLocation(), *this, info.reduceOperatorList, reduceVars,
- reduceVarByRef, reductionDeclSymbols, info.reduceSymList);
+ reduceVarByRef, reductionDeclSymbols, info.reduceSymList,
+ info.reduceObjList, getSymbolMap());
if (!result)
TODO(toLocation(), "Lowering unrecognised reduction type");
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index a57fce53c0ca5..12188706575e0 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1592,6 +1592,29 @@ static bool hasDeferredCharacterLength(const Fortran::semantics::Symbol &sym) {
type->characterTypeSpec().length().isDeferred();
}
+static mlir::Value
+findOverriddenExprValue(const Fortran::lower::ExprToValueMap &map,
+ const Fortran::lower::SomeExpr &expr) {
+ if (auto match = map.find(&expr); match != map.end())
+ return match->second;
+
+ if (!Fortran::evaluate::IsArrayElement(expr))
+ return {};
+
+ for (auto [key, value] : map) {
+ if (key == llvm::DenseMapInfo<
+ const Fortran::lower::SomeExpr *>::getEmptyKey() ||
+ key == llvm::DenseMapInfo<
+ const Fortran::lower::SomeExpr *>::getTombstoneKey())
+ continue;
+ if (Fortran::evaluate::IsArrayElement(*key) &&
+ key->AsFortran() == expr.AsFortran())
+ return value;
+ }
+
+ return {};
+}
+
/// Lower Expr to HLFIR.
class HlfirBuilder {
public:
@@ -1605,12 +1628,12 @@ class HlfirBuilder {
if (const Fortran::lower::ExprToValueMap *map =
getConverter().getExprOverrides()) {
if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
- if (auto match = map->find(&expr); match != map->end())
- return hlfir::EntityWithAttributes{match->second};
+ if (mlir::Value value = findOverriddenExprValue(*map, expr))
+ return hlfir::EntityWithAttributes{value};
} else {
Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
- if (auto match = map->find(&someExpr); match != map->end())
- return hlfir::EntityWithAttributes{match->second};
+ if (mlir::Value value = findOverriddenExprValue(*map, someExpr))
+ return hlfir::EntityWithAttributes{value};
}
}
return Fortran::common::visit([&](const auto &x) { return gen(x); },
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 1c39e90a922cf..0d317fd6fa496 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1498,31 +1498,36 @@ bool ClauseProcessor::processIf(
template <typename T>
void collectReductionSyms(
const T &reduction,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
+ llvm::SmallVectorImpl<const Object *> &reductionObjs) {
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
for (const Object &object : objectList) {
const semantics::Symbol *symbol = object.sym();
reductionSyms.push_back(symbol);
+ reductionObjs.push_back(&object);
}
}
bool ClauseProcessor::processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable, ReductionValueCache *reductionVarCache) const {
return findRepeatableClause<omp::clause::InReduction>(
[&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> inReductionVars;
llvm::SmallVector<bool> inReduceVarByRef;
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
- collectReductionSyms(clause, inReductionSyms);
+ llvm::SmallVector<const Object *> inReductionObjs;
+ collectReductionSyms(clause, inReductionSyms, inReductionObjs);
ReductionProcessor rp;
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
currentLocation, converter,
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
inReductionVars, inReduceVarByRef, inReductionDeclSymbols,
- inReductionSyms))
+ inReductionSyms, inReductionObjs, symTable, reductionVarCache))
TODO(currentLocation, "Lowering unrecognised reduction type");
// Copy local lists into the output.
@@ -1532,6 +1537,7 @@ bool ClauseProcessor::processInReduction(
llvm::copy(inReductionDeclSymbols,
std::back_inserter(result.inReductionSyms));
llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
+ llvm::copy(inReductionObjs, std::back_inserter(outReductionObjs));
});
}
@@ -2024,15 +2030,16 @@ bool ClauseProcessor::processNontemporal(
bool ClauseProcessor::processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> *reductionVarCache)
- const {
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable, ReductionValueCache *reductionVarCache) const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> reductionVars;
llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
- collectReductionSyms(clause, reductionSyms);
+ llvm::SmallVector<const Object *> reductionObjs;
+ collectReductionSyms(clause, reductionSyms, reductionObjs);
auto mod = std::get<std::optional<ReductionModifier>>(clause.t);
if (mod.has_value()) {
@@ -2049,7 +2056,7 @@ bool ClauseProcessor::processReduction(
currentLocation, converter,
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
reductionVars, reduceVarByRef, reductionDeclSymbols,
- reductionSyms, reductionVarCache))
+ reductionSyms, reductionObjs, symTable, reductionVarCache))
TODO(currentLocation, "Lowering unrecognised reduction type");
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
@@ -2057,26 +2064,31 @@ bool ClauseProcessor::processReduction(
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionSyms));
llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
+ llvm::copy(reductionObjs, std::back_inserter(outReductionObjs));
});
}
bool ClauseProcessor::processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable, ReductionValueCache *reductionVarCache) const {
return findRepeatableClause<omp::clause::TaskReduction>(
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> taskReductionVars;
llvm::SmallVector<bool> taskReduceVarByRef;
llvm::SmallVector<mlir::Attribute> taskReductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
- collectReductionSyms(clause, taskReductionSyms);
+ llvm::SmallVector<const Object *> taskReductionObjs;
+ collectReductionSyms(clause, taskReductionSyms, taskReductionObjs);
ReductionProcessor rp;
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
currentLocation, converter,
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
taskReductionVars, taskReduceVarByRef, taskReductionDeclSymbols,
- taskReductionSyms))
+ taskReductionSyms, taskReductionObjs, symTable,
+ reductionVarCache))
TODO(currentLocation, "Lowering unrecognised reduction type");
// Copy local lists into the output.
llvm::copy(taskReductionVars,
@@ -2086,6 +2098,7 @@ bool ClauseProcessor::processTaskReduction(
llvm::copy(taskReductionDeclSymbols,
std::back_inserter(result.taskReductionSyms));
llvm::copy(taskReductionSyms, std::back_inserter(outReductionSyms));
+ llvm::copy(taskReductionObjs, std::back_inserter(outReductionObjs));
});
}
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index acf1068efb987..b34b531355e1d 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -142,7 +142,10 @@ class ClauseProcessor {
mlir::omp::IfClauseOps &result) const;
bool processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable,
+ ReductionValueCache *reductionVarCache = nullptr) const;
bool processIsDevicePtr(
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
@@ -167,11 +170,15 @@ class ClauseProcessor {
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value>
- *reductionVarCache = nullptr) const;
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable,
+ ReductionValueCache *reductionVarCache = nullptr) const;
bool processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable,
+ ReductionValueCache *reductionVarCache = nullptr) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 88d28cf94b045..0484464460225 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -75,6 +75,20 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
mlir::Location loc);
namespace {
+static bool isReductionObjectExpression(const Object *object) {
+ if (!object || !object->ref())
+ return false;
+ const SomeExpr &expr = *object->ref();
+ return evaluate::IsArrayElement(expr);
+}
+
+static std::optional<const SomeExpr *>
+getReductionObjectExpr(const Object *object) {
+ if (!isReductionObjectExpression(object))
+ return std::nullopt;
+ return &object->ref().value();
+}
+
/// Structure holding information that is needed to pass host-evaluated
/// information to later lowering stages.
class HostEvalInfo {
@@ -319,22 +333,52 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
bindSingleMapLike(*sym, arg);
};
- auto bindPrivateLike = [&converter, &firOpBuilder](
+ llvm::SmallPtrSet<const semantics::Symbol *, 8> objectReductionSyms;
+ auto collectObjectReductionSyms =
+ [&objectReductionSyms](llvm::ArrayRef<const Object *> objs) {
+ for (const Object *obj : objs)
+ if (isReductionObjectExpression(obj))
+ objectReductionSyms.insert(&obj->sym()->GetUltimate());
+ };
+ collectObjectReductionSyms(args.inReduction.objs);
+ collectObjectReductionSyms(args.reduction.objs);
+ collectObjectReductionSyms(args.taskReduction.objs);
+
+ auto bindPrivateLike = [&converter, &firOpBuilder, &objectReductionSyms](
llvm::ArrayRef<const semantics::Symbol *> syms,
llvm::ArrayRef<mlir::Value> vars,
- llvm::ArrayRef<mlir::BlockArgument> args) {
+ llvm::ArrayRef<mlir::BlockArgument> args,
+ llvm::ArrayRef<const Object *> objs,
+ bool skipObjectReductionSyms = false) {
+ assert((objs.empty() || objs.size() == syms.size()) &&
+ "invalid object list for private-like clause");
llvm::SmallVector<const semantics::Symbol *> processedSyms;
- for (auto *sym : syms) {
+ llvm::SmallVector<const Object *> processedObjs;
+ for (auto [idx, sym] : llvm::enumerate(syms)) {
+ const Object *obj = objs.empty() ? nullptr : objs[idx];
if (const auto *commonDet =
sym->detailsIf<semantics::CommonBlockDetails>()) {
- llvm::transform(commonDet->objects(), std::back_inserter(processedSyms),
- [&](const auto &mem) { return &*mem; });
+ for (auto &mem : commonDet->objects()) {
+ processedSyms.push_back(&*mem);
+ processedObjs.push_back(obj);
+ }
} else {
processedSyms.push_back(sym);
+ processedObjs.push_back(obj);
}
}
- for (auto [sym, var, arg] : llvm::zip_equal(processedSyms, vars, args))
+ assert(processedSyms.size() == processedObjs.size());
+ for (auto [sym, var, arg, obj] :
+ llvm::zip_equal(processedSyms, vars, args, processedObjs)) {
+ bool skipBind =
+ isReductionObjectExpression(obj) ||
+ (obj && sym->Rank() > 0 && !fir::unwrapUntilSeqType(arg.getType())) ||
+ (skipObjectReductionSyms &&
+ objectReductionSyms.contains(&sym->GetUltimate()));
+ if (skipBind)
+ continue;
+
converter.bindSymbol(
*sym,
hlfir::translateToExtendedValue(
@@ -342,6 +386,7 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
/*contiguousHint=*/
evaluate::IsSimplyContiguous(*sym, converter.getFoldingContext()))
.first);
+ }
};
// Process in clause name alphabetical order to match block arguments order.
@@ -349,13 +394,14 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
// corresponding region, except for very specific cases handled separately.
bindMapLike(args.hasDeviceAddr.syms, op.getHasDeviceAddrBlockArgs());
bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
- op.getInReductionBlockArgs());
+ op.getInReductionBlockArgs(), args.inReduction.objs);
bindMapLike(args.map.syms, op.getMapBlockArgs());
- bindPrivateLike(args.priv.syms, args.priv.vars, op.getPrivateBlockArgs());
+ bindPrivateLike(args.priv.syms, args.priv.vars, op.getPrivateBlockArgs(),
+ args.priv.objs, /*skipObjectReductionSyms=*/true);
bindPrivateLike(args.reduction.syms, args.reduction.vars,
- op.getReductionBlockArgs());
+ op.getReductionBlockArgs(), args.reduction.objs);
bindPrivateLike(args.taskReduction.syms, args.taskReduction.vars,
- op.getTaskReductionBlockArgs());
+ op.getTaskReductionBlockArgs(), args.taskReduction.objs);
bindMapLike(args.useDeviceAddr.syms, op.getUseDeviceAddrBlockArgs());
bindMapLike(args.useDevicePtr.syms, op.getUseDevicePtrBlockArgs());
}
@@ -866,6 +912,7 @@ getDeclareTargetFunctionDevice(
/// \param [in] args - symbols of induction variables.
/// \param [in] wrapperArgs - list of parent loop wrappers and their associated
/// entry block arguments.
+
static void genLoopVars(
mlir::Operation *op, lower::AbstractConverter &converter,
mlir::Location &loc, llvm::ArrayRef<const semantics::Symbol *> args,
@@ -1170,6 +1217,36 @@ struct OpWithBodyGenInfo {
bool privatize = true;
};
+static mlir::Value getReductionOverrideValue(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ const Object *object,
+ mlir::BlockArgument arg) {
+ if (hlfir::isFortranEntityWithAttributes(arg))
+ return arg;
+
+ fir::FortranVariableFlagsAttr attributes;
+ llvm::SmallVector<mlir::Value> typeParams;
+ auto declareOp = hlfir::DeclareOp::create(
+ builder, loc, arg, "omp.reduction.element", nullptr, typeParams, nullptr,
+ nullptr, 0, attributes);
+ return declareOp.getBase();
+}
+
+static void
+addReductionObjectOverrides(fir::FirOpBuilder &builder, mlir::Location loc,
+ lower::ExprToValueMap &overrides,
+ const EntryBlockArgsEntry &entry,
+ llvm::ArrayRef<mlir::BlockArgument> blockArgs) {
+ if (entry.objs.empty())
+ return;
+
+ assert(entry.objs.size() == blockArgs.size() &&
+ "reduction object list must match block arguments");
+ for (auto [object, arg] : llvm::zip_equal(entry.objs, blockArgs))
+ if (std::optional<const SomeExpr *> expr = getReductionObjectExpr(object))
+ overrides[*expr] = getReductionOverrideValue(builder, loc, object, arg);
+}
+
/// Create the body (block) for an OpenMP Operation.
///
/// \param [in] op - the operation the body belongs to.
@@ -1249,6 +1326,27 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
}
if (!info.genSkeletonOnly) {
+ lower::ExprToValueMap local;
+ if (auto *old = info.converter.getExprOverrides())
+ local.insert(old->begin(), old->end());
+ if (info.blockArgs) {
+ if (auto ompBlockArgOp =
+ mlir::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(op)) {
+ addReductionObjectOverrides(firOpBuilder, info.loc, local,
+ info.blockArgs->inReduction,
+ ompBlockArgOp.getInReductionBlockArgs());
+ addReductionObjectOverrides(firOpBuilder, info.loc, local,
+ info.blockArgs->reduction,
+ ompBlockArgOp.getReductionBlockArgs());
+ addReductionObjectOverrides(firOpBuilder, info.loc, local,
+ info.blockArgs->taskReduction,
+ ompBlockArgOp.getTaskReductionBlockArgs());
+ }
+ }
+
+ auto *old = info.converter.getExprOverrides();
+ info.converter.overrideExprValues(local.empty() ? old : &local);
+
if (ConstructQueue::const_iterator next = std::next(item);
next != queue.end()) {
genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
@@ -1264,6 +1362,8 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
genNestedEvaluations(info.converter, info.eval);
temp->erase();
}
+
+ info.converter.overrideExprValues(old);
}
// Get or create a unique exiting block from the given region, or
@@ -1595,15 +1695,18 @@ genLoopNestClauses(lower::AbstractConverter &converter,
cp.processTileSizes(eval, clauseOps);
}
-static void genLoopClauses(
- lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
- const List<Clause> &clauses, mlir::Location loc,
- mlir::omp::LoopOperands &clauseOps,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+static void
+genLoopClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::LoopOperands &clauseOps,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
+ llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symTable) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processBind(clauseOps);
cp.processOrder(clauseOps);
- cp.processReduction(loc, clauseOps, reductionSyms);
+ cp.processReduction(loc, clauseOps, reductionSyms, reductionObjs, symTable);
cp.processTODO<clause::Lastprivate>(loc, llvm::omp::Directive::OMPD_loop);
}
@@ -1629,7 +1732,9 @@ static void genParallelClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, mlir::omp::ParallelOperands &clauseOps,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
+ llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symTable) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
@@ -1639,7 +1744,7 @@ static void genParallelClauses(
cp.processNumThreads(stmtCtx, clauseOps);
cp.processProcBind(clauseOps);
- cp.processReduction(loc, clauseOps, reductionSyms);
+ cp.processReduction(loc, clauseOps, reductionSyms, reductionObjs, symTable);
}
static void genScanClauses(lower::AbstractConverter &converter,
@@ -1655,11 +1760,13 @@ static void genSectionsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::SectionsOperands &clauseOps,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
+ llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symTable) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processNowait(clauseOps);
- cp.processReduction(loc, clauseOps, reductionSyms);
+ cp.processReduction(loc, clauseOps, reductionSyms, reductionObjs, symTable);
// TODO Support delayed privatization.
}
@@ -1668,14 +1775,15 @@ static void genSimdClauses(
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::SimdOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> *reductionVarCache =
- nullptr) {
+ llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symTable, ReductionValueCache *reductionVarCache = nullptr) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAligned(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps);
cp.processNontemporal(clauseOps);
cp.processOrder(clauseOps);
- cp.processReduction(loc, clauseOps, reductionSyms, reductionVarCache);
+ cp.processReduction(loc, clauseOps, reductionSyms, reductionObjs, symTable,
+ reductionVarCache);
cp.processSafelen(clauseOps);
cp.processSimdlen(clauseOps);
cp.processLinear(clauseOps);
@@ -1751,15 +1859,17 @@ genSimdImplicitLinear(lower::AbstractConverter &converter,
}
}
-static void genScopeClauses(
- lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
- const List<Clause> &clauses, mlir::Location loc,
- mlir::omp::ScopeOperands &clauseOps,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+static void
+genScopeClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ lower::SymMap &symTable, mlir::omp::ScopeOperands &clauseOps,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
+ llvm::SmallVectorImpl<const Object *> &reductionObjs) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processNowait(clauseOps);
- cp.processReduction(loc, clauseOps, reductionSyms);
+ cp.processReduction(loc, clauseOps, reductionSyms, reductionObjs, symTable);
}
static void genSingleClauses(lower::AbstractConverter &converter,
@@ -1859,14 +1969,17 @@ static void genTaskClauses(
lower::SymMap &symTable, lower::StatementContext &stmtCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskOperands &clauseOps,
- llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &inReductionObjs,
+ ReductionValueCache *reductionVarCache = nullptr) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAffinity(clauseOps);
cp.processAllocate(clauseOps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
- cp.processInReduction(loc, clauseOps, inReductionSyms);
+ cp.processInReduction(loc, clauseOps, inReductionSyms, inReductionObjs,
+ symTable, reductionVarCache);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
cp.processUntied(clauseOps);
@@ -1877,10 +1990,13 @@ static void genTaskgroupClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps,
- llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &taskReductionObjs,
+ lower::SymMap &symTable) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
- cp.processTaskReduction(loc, clauseOps, taskReductionSyms);
+ cp.processTaskReduction(loc, clauseOps, taskReductionSyms, taskReductionObjs,
+ symTable);
}
static void genTaskloopClauses(
@@ -1888,19 +2004,23 @@ static void genTaskloopClauses(
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, mlir::omp::TaskloopContextOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
- llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
+ llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &inReductionObjs,
+ lower::SymMap &symTable) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processGrainsize(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_taskloop, clauseOps);
- cp.processInReduction(loc, clauseOps, inReductionSyms);
+ cp.processInReduction(loc, clauseOps, inReductionSyms, inReductionObjs,
+ symTable);
cp.processMergeable(clauseOps);
cp.processNogroup(clauseOps);
cp.processNumTasks(stmtCtx, clauseOps);
cp.processPriority(stmtCtx, clauseOps);
- cp.processReduction(loc, clauseOps, reductionSyms);
+ cp.processReduction(loc, clauseOps, reductionSyms, reductionObjs, symTable);
cp.processUntied(clauseOps);
}
@@ -1922,11 +2042,14 @@ static void genWorkshareClauses(lower::AbstractConverter &converter,
cp.processNowait(clauseOps);
}
-static void genTeamsClauses(
- lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
- lower::StatementContext &stmtCtx, const List<Clause> &clauses,
- mlir::Location loc, mlir::omp::TeamsOperands &clauseOps,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+static void
+genTeamsClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::TeamsOperands &clauseOps,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
+ llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symTable) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
@@ -1937,7 +2060,7 @@ static void genTeamsClauses(
cp.processThreadLimit(stmtCtx, clauseOps);
}
- cp.processReduction(loc, clauseOps, reductionSyms);
+ cp.processReduction(loc, clauseOps, reductionSyms, reductionObjs, symTable);
// TODO Support delayed privatization.
}
@@ -1946,14 +2069,15 @@ static void genWsloopClauses(
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, mlir::omp::WsloopOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> *reductionVarCache =
- nullptr) {
+ llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symTable, ReductionValueCache *reductionVarCache = nullptr) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processNowait(clauseOps);
cp.processOrder(clauseOps);
cp.processOrdered(clauseOps);
- cp.processReduction(loc, clauseOps, reductionSyms, reductionVarCache);
+ cp.processReduction(loc, clauseOps, reductionSyms, reductionObjs, symTable,
+ reductionVarCache);
cp.processSchedule(stmtCtx, clauseOps);
cp.processLinear(clauseOps);
}
@@ -2074,21 +2198,40 @@ static mlir::omp::LoopNestOp genLoopNestOp(
std::pair<mlir::omp::BlockArgOpenMPOpInterface, const EntryBlockArgs &>>
wrapperArgs,
llvm::omp::Directive directive, DataSharingProcessor &dsp) {
+ const lower::ExprToValueMap *oldOverrides = converter.getExprOverrides();
+ lower::ExprToValueMap loopNestOverrides;
auto ivCallback = [&](mlir::Operation *op) {
genLoopVars(op, converter, loc, iv, wrapperArgs);
+ if (oldOverrides)
+ loopNestOverrides.insert(oldOverrides->begin(), oldOverrides->end());
+ for (auto [argGeneratingOp, blockArgs] : wrapperArgs) {
+ addReductionObjectOverrides(converter.getFirOpBuilder(), loc,
+ loopNestOverrides, blockArgs.inReduction,
+ argGeneratingOp.getInReductionBlockArgs());
+ addReductionObjectOverrides(converter.getFirOpBuilder(), loc,
+ loopNestOverrides, blockArgs.reduction,
+ argGeneratingOp.getReductionBlockArgs());
+ addReductionObjectOverrides(converter.getFirOpBuilder(), loc,
+ loopNestOverrides, blockArgs.taskReduction,
+ argGeneratingOp.getTaskReductionBlockArgs());
+ }
+ converter.overrideExprValues(
+ loopNestOverrides.empty() ? oldOverrides : &loopNestOverrides);
return llvm::SmallVector<const semantics::Symbol *>(iv);
};
uint64_t nestValue = getCollapseValue(item->clauses);
nestValue = nestValue < iv.size() ? iv.size() : nestValue;
auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
- return genOpWithBody<mlir::omp::LoopNestOp>(
+ auto loopNestOp = genOpWithBody<mlir::omp::LoopNestOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
directive)
.setClauses(&item->clauses)
.setDataSharingProcessor(&dsp)
.setGenRegionEntryCb(ivCallback),
queue, item, clauseOps);
+ converter.overrideExprValues(oldOverrides);
+ return loopNestOp;
}
static mlir::omp::LoopOp
@@ -2098,8 +2241,9 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
mlir::omp::LoopOperands loopClauseOps;
llvm::SmallVector<const semantics::Symbol *> loopReductionSyms;
+ llvm::SmallVector<const Object *> loopReductionObjs;
genLoopClauses(converter, semaCtx, item->clauses, loc, loopClauseOps,
- loopReductionSyms);
+ loopReductionSyms, loopReductionObjs, symTable);
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
@@ -2116,6 +2260,7 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
loopArgs.priv.vars = loopClauseOps.privateVars;
loopArgs.reduction.syms = loopReductionSyms;
loopArgs.reduction.vars = loopClauseOps.reductionVars;
+ loopArgs.reduction.objs = loopReductionObjs;
auto loopOp =
genWrapperOp<mlir::omp::LoopOp>(converter, loc, loopClauseOps, loopArgs);
@@ -2561,8 +2706,9 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
std::get<std::list<parser::OpenMPConstruct>>(sectionsConstruct->t);
mlir::omp::SectionsOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
+ llvm::SmallVector<const Object *> reductionObjs;
genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
- reductionSyms);
+ reductionSyms, reductionObjs, symTable);
auto &builder = converter.getFirOpBuilder();
@@ -2601,6 +2747,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// TODO: Add private syms and vars.
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;
+ args.reduction.objs = reductionObjs;
genEntryBlock(builder, args, sectionsOp.getRegion());
mlir::Operation *terminator =
@@ -2683,8 +2830,9 @@ genScopeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
lower::SymMapScope scope(symTable);
mlir::omp::ScopeOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
- genScopeClauses(converter, semaCtx, item->clauses, loc, clauseOps,
- reductionSyms);
+ llvm::SmallVector<const Object *> reductionObjs;
+ genScopeClauses(converter, semaCtx, item->clauses, loc, symTable, clauseOps,
+ reductionSyms, reductionObjs);
std::optional<DataSharingProcessor> dsp;
if (enableDelayedPrivatization) {
@@ -2700,6 +2848,7 @@ genScopeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
args.priv.vars = clauseOps.privateVars;
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;
+ args.reduction.objs = reductionObjs;
return genOpWithBody<mlir::omp::ScopeOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
@@ -3094,8 +3243,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
mlir::omp::TaskOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
+ llvm::SmallVector<const Object *> inReductionObjs;
genTaskClauses(converter, semaCtx, symTable, stmtCtx, item->clauses, loc,
- clauseOps, inReductionSyms);
+ clauseOps, inReductionSyms, inReductionObjs);
if (!enableDelayedPrivatization)
return genOpWithBody<mlir::omp::TaskOp>(
@@ -3114,6 +3264,7 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
taskArgs.priv.vars = clauseOps.privateVars;
taskArgs.inReduction.syms = inReductionSyms;
taskArgs.inReduction.vars = clauseOps.inReductionVars;
+ taskArgs.inReduction.objs = inReductionObjs;
return genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
@@ -3132,12 +3283,14 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
mlir::omp::TaskgroupOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
+ llvm::SmallVector<const Object *> taskReductionObjs;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
- taskReductionSyms);
+ taskReductionSyms, taskReductionObjs, symTable);
EntryBlockArgs taskgroupArgs;
taskgroupArgs.taskReduction.syms = taskReductionSyms;
taskgroupArgs.taskReduction.vars = clauseOps.taskReductionVars;
+ taskgroupArgs.taskReduction.objs = taskReductionObjs;
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
@@ -3193,13 +3346,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
lower::SymMapScope scope(symTable);
mlir::omp::TeamsOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
+ llvm::SmallVector<const Object *> reductionObjs;
genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
- reductionSyms);
+ reductionSyms, reductionObjs, symTable);
EntryBlockArgs args;
// TODO: Add private syms and vars.
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;
+ args.reduction.objs = reductionObjs;
return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
@@ -3262,8 +3417,10 @@ static mlir::omp::WsloopOp genStandaloneDo(
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
mlir::omp::WsloopOperands wsloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
+ llvm::SmallVector<const Object *> wsloopReductionObjs;
genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
- wsloopClauseOps, wsloopReductionSyms);
+ wsloopClauseOps, wsloopReductionSyms, wsloopReductionObjs,
+ symTable);
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
@@ -3280,6 +3437,7 @@ static mlir::omp::WsloopOp genStandaloneDo(
wsloopArgs.priv.vars = wsloopClauseOps.privateVars;
wsloopArgs.reduction.syms = wsloopReductionSyms;
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
+ wsloopArgs.reduction.objs = wsloopReductionObjs;
auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
converter, loc, wsloopClauseOps, wsloopArgs);
@@ -3297,8 +3455,10 @@ static mlir::omp::ParallelOp genStandaloneParallel(
lower::SymMapScope scope(symTable);
mlir::omp::ParallelOperands parallelClauseOps;
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
+ llvm::SmallVector<const Object *> parallelReductionObjs;
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
- parallelClauseOps, parallelReductionSyms);
+ parallelClauseOps, parallelReductionSyms,
+ parallelReductionObjs, symTable);
std::optional<DataSharingProcessor> dsp;
if (enableDelayedPrivatization) {
@@ -3314,6 +3474,7 @@ static mlir::omp::ParallelOp genStandaloneParallel(
parallelArgs.priv.vars = parallelClauseOps.privateVars;
parallelArgs.reduction.syms = parallelReductionSyms;
parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
+ parallelArgs.reduction.objs = parallelReductionObjs;
return genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item,
parallelClauseOps, parallelArgs,
enableDelayedPrivatization ? &dsp.value() : nullptr);
@@ -3327,8 +3488,9 @@ genStandaloneSimd(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
mlir::omp::SimdOperands simdClauseOps;
llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+ llvm::SmallVector<const Object *> simdReductionObjs;
genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps,
- simdReductionSyms);
+ simdReductionSyms, simdReductionObjs, symTable);
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
@@ -3347,6 +3509,7 @@ genStandaloneSimd(lower::AbstractConverter &converter, lower::SymMap &symTable,
simdArgs.priv.vars = simdClauseOps.privateVars;
simdArgs.reduction.syms = simdReductionSyms;
simdArgs.reduction.vars = simdClauseOps.reductionVars;
+ simdArgs.reduction.objs = simdReductionObjs;
auto simdOp =
genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
@@ -3362,10 +3525,13 @@ static mlir::omp::TaskloopContextOp genStandaloneTaskloop(
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
mlir::omp::TaskloopContextOperands taskloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
+ llvm::SmallVector<const Object *> reductionObjs;
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
+ llvm::SmallVector<const Object *> inReductionObjs;
genTaskloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
- taskloopClauseOps, reductionSyms, inReductionSyms);
+ taskloopClauseOps, reductionSyms, reductionObjs,
+ inReductionSyms, inReductionObjs, symTable);
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
enableDelayedPrivatization, symTable);
@@ -3381,8 +3547,10 @@ static mlir::omp::TaskloopContextOp genStandaloneTaskloop(
taskloopArgs.priv.vars = taskloopClauseOps.privateVars;
taskloopArgs.reduction.syms = reductionSyms;
taskloopArgs.reduction.vars = taskloopClauseOps.reductionVars;
+ taskloopArgs.reduction.objs = reductionObjs;
taskloopArgs.inReduction.syms = inReductionSyms;
taskloopArgs.inReduction.vars = taskloopClauseOps.inReductionVars;
+ taskloopArgs.inReduction.objs = inReductionObjs;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
auto taskLoopContextOp = mlir::omp::TaskloopContextOp::create(
@@ -3423,8 +3591,10 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDo(
// Create parent omp.parallel first.
mlir::omp::ParallelOperands parallelClauseOps;
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
+ llvm::SmallVector<const Object *> parallelReductionObjs;
genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
- parallelClauseOps, parallelReductionSyms);
+ parallelClauseOps, parallelReductionSyms,
+ parallelReductionObjs, symTable);
DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
@@ -3436,6 +3606,7 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDo(
parallelArgs.priv.vars = parallelClauseOps.privateVars;
parallelArgs.reduction.syms = parallelReductionSyms;
parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
+ parallelArgs.reduction.objs = parallelReductionObjs;
genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true);
@@ -3446,8 +3617,10 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDo(
mlir::omp::WsloopOperands wsloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
+ llvm::SmallVector<const Object *> wsloopReductionObjs;
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
- wsloopClauseOps, wsloopReductionSyms);
+ wsloopClauseOps, wsloopReductionSyms, wsloopReductionObjs,
+ symTable);
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
@@ -3465,6 +3638,7 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDo(
// TODO: Add private syms and vars.
wsloopArgs.reduction.syms = wsloopReductionSyms;
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
+ wsloopArgs.reduction.objs = wsloopReductionObjs;
auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
converter, loc, wsloopClauseOps, wsloopArgs);
wsloopOp.setComposite(/*val=*/true);
@@ -3490,8 +3664,10 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDoSimd(
// Create parent omp.parallel first.
mlir::omp::ParallelOperands parallelClauseOps;
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
+ llvm::SmallVector<const Object *> parallelReductionObjs;
genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
- parallelClauseOps, parallelReductionSyms);
+ parallelClauseOps, parallelReductionSyms,
+ parallelReductionObjs, symTable);
DataSharingProcessor parallelItemDSP(
converter, semaCtx, parallelItem->clauses, eval,
@@ -3504,6 +3680,7 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDoSimd(
parallelArgs.priv.vars = parallelClauseOps.privateVars;
parallelArgs.reduction.syms = parallelReductionSyms;
parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
+ parallelArgs.reduction.objs = parallelReductionObjs;
genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
parallelClauseOps, parallelArgs, ¶llelItemDSP,
/*isComposite=*/true);
@@ -3511,7 +3688,7 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDoSimd(
// Clause processing.
// Use a shared cache so that both wsloop and simd produce the same SSA
// values for array/box reduction variables. See genCompositeDoSimd.
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> reductionVarCache;
+ ReductionValueCache reductionVarCache;
mlir::omp::DistributeOperands distributeClauseOps;
genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
@@ -3519,13 +3696,17 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDoSimd(
mlir::omp::WsloopOperands wsloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
+ llvm::SmallVector<const Object *> wsloopReductionObjs;
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
- wsloopClauseOps, wsloopReductionSyms, &reductionVarCache);
+ wsloopClauseOps, wsloopReductionSyms, wsloopReductionObjs,
+ symTable, &reductionVarCache);
mlir::omp::SimdOperands simdClauseOps;
llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+ llvm::SmallVector<const Object *> simdReductionObjs;
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
- simdReductionSyms, &reductionVarCache);
+ simdReductionSyms, simdReductionObjs, symTable,
+ &reductionVarCache);
DataSharingProcessor simdItemDSP(converter, semaCtx, simdItem->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
@@ -3550,6 +3731,7 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDoSimd(
// TODO: Add private syms and vars.
wsloopArgs.reduction.syms = wsloopReductionSyms;
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
+ wsloopArgs.reduction.objs = wsloopReductionObjs;
auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
converter, loc, wsloopClauseOps, wsloopArgs);
wsloopOp.setComposite(/*val=*/true);
@@ -3559,6 +3741,7 @@ static mlir::omp::DistributeOp genCompositeDistributeParallelDoSimd(
simdArgs.priv.vars = simdClauseOps.privateVars;
simdArgs.reduction.syms = simdReductionSyms;
simdArgs.reduction.vars = simdClauseOps.reductionVars;
+ simdArgs.reduction.objs = simdReductionObjs;
auto simdOp =
genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
simdOp.setComposite(/*val=*/true);
@@ -3589,8 +3772,9 @@ static mlir::omp::DistributeOp genCompositeDistributeSimd(
mlir::omp::SimdOperands simdClauseOps;
llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+ llvm::SmallVector<const Object *> simdReductionObjs;
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
- simdReductionSyms);
+ simdReductionSyms, simdReductionObjs, symTable);
DataSharingProcessor distributeItemDSP(
converter, semaCtx, distributeItem->clauses, eval,
@@ -3625,6 +3809,7 @@ static mlir::omp::DistributeOp genCompositeDistributeSimd(
simdArgs.priv.vars = simdClauseOps.privateVars;
simdArgs.reduction.syms = simdReductionSyms;
simdArgs.reduction.vars = simdClauseOps.reductionVars;
+ simdArgs.reduction.objs = simdReductionObjs;
auto simdOp =
genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
simdOp.setComposite(/*val=*/true);
@@ -3650,17 +3835,21 @@ static mlir::omp::WsloopOp genCompositeDoSimd(
// values for array/box reduction variables, enabling genLoopVars()'s
// IRMapping to correctly chain the inner wrapper's operands to the outer
// wrapper's block arguments.
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> reductionVarCache;
+ ReductionValueCache reductionVarCache;
mlir::omp::WsloopOperands wsloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
+ llvm::SmallVector<const Object *> wsloopReductionObjs;
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
- wsloopClauseOps, wsloopReductionSyms, &reductionVarCache);
+ wsloopClauseOps, wsloopReductionSyms, wsloopReductionObjs,
+ symTable, &reductionVarCache);
mlir::omp::SimdOperands simdClauseOps;
llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+ llvm::SmallVector<const Object *> simdReductionObjs;
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
- simdReductionSyms, &reductionVarCache);
+ simdReductionSyms, simdReductionObjs, symTable,
+ &reductionVarCache);
DataSharingProcessor wsloopItemDSP(
converter, semaCtx, doItem->clauses, eval,
@@ -3688,6 +3877,7 @@ static mlir::omp::WsloopOp genCompositeDoSimd(
wsloopArgs.priv.vars = wsloopClauseOps.privateVars;
wsloopArgs.reduction.syms = wsloopReductionSyms;
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
+ wsloopArgs.reduction.objs = wsloopReductionObjs;
auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
converter, loc, wsloopClauseOps, wsloopArgs);
wsloopOp.setComposite(/*val=*/true);
@@ -3697,6 +3887,7 @@ static mlir::omp::WsloopOp genCompositeDoSimd(
simdArgs.priv.vars = simdClauseOps.privateVars;
simdArgs.reduction.syms = simdReductionSyms;
simdArgs.reduction.vars = simdClauseOps.reductionVars;
+ simdArgs.reduction.objs = simdReductionObjs;
auto simdOp =
genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
simdOp.setComposite(/*val=*/true);
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index d5387f7a59118..250d1aaeadf6c 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -13,6 +13,7 @@
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/Support/PrivateReductionUtils.h"
@@ -38,6 +39,13 @@ namespace Fortran {
namespace lower {
namespace omp {
+/// True when \p object carries a designator that is an array element.
+static bool isReductionObjectExpression(const Object *object) {
+  if (object == nullptr || !object->ref())
+    return false;
+  return evaluate::IsArrayElement(*object->ref());
+}
+
// explicit template declarations
template bool ReductionProcessor::processReductionArguments<
mlir::omp::DeclareReductionOp, omp::clause::ReductionOperatorList>(
@@ -47,7 +55,8 @@ template bool ReductionProcessor::processReductionArguments<
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> *reductionVarCache);
+ const llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symMap, ReductionValueCache *reductionVarCache);
template bool ReductionProcessor::processReductionArguments<
fir::DeclareReductionOp, llvm::SmallVector<fir::ReduceOperationEnum>>(
@@ -57,7 +66,8 @@ template bool ReductionProcessor::processReductionArguments<
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> *reductionVarCache);
+ const llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symMap, ReductionValueCache *reductionVarCache);
template mlir::omp::DeclareReductionOp
ReductionProcessor::createDeclareReduction<mlir::omp::DeclareReductionOp>(
@@ -661,7 +671,8 @@ bool ReductionProcessor::processReductionArguments(
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> *reductionVarCache) {
+ const llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symMap, ReductionValueCache *reductionVarCache) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
if constexpr (std::is_same_v<RedOperatorListTy,
@@ -703,80 +714,116 @@ bool ReductionProcessor::processReductionArguments(
builder.getRegion().getParentOfType<fir::DoConcurrentOp>());
}
- for (const semantics::Symbol *symbol : reductionSymbols) {
+ assert((reductionObjs.empty() ||
+ reductionSymbols.size() == reductionObjs.size()) &&
+ "mismatched reduction symbol and object lists");
+
+ for (unsigned i = 0; i < reductionSymbols.size(); ++i) {
+ const Object *object = reductionObjs.empty() ? nullptr : reductionObjs[i];
+ const semantics::Symbol *symbol =
+ object ? object->sym() : reductionSymbols[i];
+ const SomeExpr *expr = object && object->ref() ? &*object->ref() : nullptr;
+ const bool isObjectExpr = isReductionObjectExpression(object);
+
// If a cached reduction variable exists for this symbol, reuse it.
// This ensures that composite constructs (e.g. DO SIMD) where both
// the outer wrapper (wsloop) and inner wrapper (simd) process the same
// reduction clause share the same SSA value, enabling genLoopVars()'s
// IRMapping to correctly remap inner wrapper operands to outer wrapper
- // block arguments.
+ // block arguments. If an Expr is present for the symbol this is used,
+ // otherwise the symbol is used. This ensures that Expressions such as
+ // Array Elements are correctly represented when lowered.
if (reductionVarCache) {
- auto it = reductionVarCache->find(symbol);
- if (it != reductionVarCache->end()) {
+ if (isObjectExpr) {
+ auto it = reductionVarCache->exprCache.find(expr);
+ if (it != reductionVarCache->exprCache.end()) {
+ reductionVars.push_back(it->second);
+ reduceVarByRef.push_back(doReductionByRef(it->second));
+ continue;
+ }
+ } else if (auto it = reductionVarCache->symbolCache.find(symbol);
+ it != reductionVarCache->symbolCache.end()) {
reductionVars.push_back(it->second);
reduceVarByRef.push_back(doReductionByRef(it->second));
continue;
}
}
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
-
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
-
- mlir::Type eleType;
- auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
- if (refType)
- eleType = refType.getEleTy();
- else
- eleType = symVal.getType();
-
- // all arrays must be boxed so that we have convenient access to all the
- // information needed to iterate over the array
- if (mlir::isa<fir::SequenceType>(eleType)) {
- // For Host associated symbols, use `SymbolBox` instead
- lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
- hlfir::Entity entity{symBox.getAddr()};
- entity = genVariableBox(currentLocation, builder, entity);
- mlir::Value box = entity.getBase();
-
- // Always pass the box by reference so that the OpenMP dialect
- // verifiers don't need to know anything about fir.box
- auto alloca =
- fir::AllocaOp::create(builder, currentLocation, box.getType());
- fir::StoreOp::create(builder, currentLocation, box, alloca);
-
- symVal = alloca;
- } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
- // boxed arrays are passed as values not by reference. Unfortunately,
- // we can't pass a box by value to omp.redution_declare, so turn it
- // into a reference
- auto oldIP = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(builder.getAllocaBlock());
- auto alloca =
- fir::AllocaOp::create(builder, currentLocation, symVal.getType());
- builder.restoreInsertionPoint(oldIP);
- fir::StoreOp::create(builder, currentLocation, symVal, alloca);
- symVal = alloca;
- }
+ mlir::Value reductionVal;
+ mlir::Type refTy;
- // this isn't the same as the by-val and by-ref passing later in the
- // pipeline. Both styles assume that the variable is a reference at
- // this point
- assert(fir::isa_ref_type(symVal.getType()) &&
- "reduction input var is passed by reference");
- mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType());
- const bool symIsVolatile = fir::isa_volatile_type(symVal.getType());
- mlir::Type refTy = fir::ReferenceType::get(elementType, symIsVolatile);
+ if (isObjectExpr) {
+ StatementContext stmtCtx;
+ hlfir::EntityWithAttributes entity = convertExprToHLFIR(
+ converter.getCurrentLocation(), converter, *expr, symMap, stmtCtx);
+ reductionVal = entity.getBase();
+ refTy = reductionVal.getType();
+ } else {
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+
+ mlir::Type eleType;
+ auto refType =
+ mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
+ if (refType)
+ eleType = refType.getEleTy();
+ else
+ eleType = symVal.getType();
+
+ // all arrays must be boxed so that we have convenient access to all the
+ // information needed to iterate over the array
+ if (mlir::isa<fir::SequenceType>(eleType)) {
+ // For Host associated symbols, use `SymbolBox` instead
+ lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
+ hlfir::Entity entity{symBox.getAddr()};
+ entity = genVariableBox(currentLocation, builder, entity);
+ mlir::Value box = entity.getBase();
+
+ // Always pass the box by reference so that the OpenMP dialect
+ // verifiers don't need to know anything about fir.box
+ auto alloca =
+ fir::AllocaOp::create(builder, currentLocation, box.getType());
+ fir::StoreOp::create(builder, currentLocation, box, alloca);
+
+ symVal = alloca;
+ } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
+ // boxed arrays are passed as values not by reference. Unfortunately,
+ // we can't pass a box by value to omp.redution_declare, so turn it
+ // into a reference
+ auto oldIP = builder.saveInsertionPoint();
+ builder.setInsertionPointToStart(builder.getAllocaBlock());
+ auto alloca =
+ fir::AllocaOp::create(builder, currentLocation, symVal.getType());
+ builder.restoreInsertionPoint(oldIP);
+ fir::StoreOp::create(builder, currentLocation, symVal, alloca);
+ symVal = alloca;
+ }
+ // this isn't the same as the by-val and by-ref passing later in the
+ // pipeline. Both styles assume that the variable is a reference at
+ // this point
+ assert(fir::isa_ref_type(symVal.getType()) &&
+ "reduction input var is passed by reference");
+ mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType());
+ const bool symIsVolatile = fir::isa_volatile_type(symVal.getType());
+ refTy = fir::ReferenceType::get(elementType, symIsVolatile);
+ reductionVal = symVal;
+ }
reductionVars.push_back(
- builder.createConvert(currentLocation, refTy, symVal));
+ builder.createConvert(currentLocation, refTy, reductionVal));
reduceVarByRef.push_back(doReductionByRef(reductionVars.back()));
// Cache the final SSA value for this symbol so that subsequent calls
// (e.g. for the inner wrapper in a composite construct) reuse it.
- if (reductionVarCache)
- reductionVarCache->try_emplace(symbol, reductionVars.back());
+ if (reductionVarCache) {
+ if (isObjectExpr)
+ reductionVarCache->exprCache.try_emplace(expr, reductionVars.back());
+ else
+ reductionVarCache->symbolCache.try_emplace(symbol,
+ reductionVars.back());
+ }
}
unsigned idx = 0;
diff --git a/flang/test/Lower/OpenMP/reduction-array-element.f90 b/flang/test/Lower/OpenMP/reduction-array-element.f90
new file mode 100644
index 0000000000000..35c2bc86f3100
--- /dev/null
+++ b/flang/test/Lower/OpenMP/reduction-array-element.f90
@@ -0,0 +1,114 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s --implicit-check-not=add_reduction_byref_box
+
+subroutine reduction_literal(a, n)  ! reduce one array element chosen by a constant subscript
+ integer :: a(4), n
+!$omp parallel do reduction(+: a(2))
+ do i = 1, n
+ a(2) = a(2) + i
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPreduction_literal
+! CHECK: omp.wsloop {{.*}} reduction(@add_reduction_i32 {{.*}} : !fir.ref<i32>) {
+! CHECK: hlfir.declare %arg{{[0-9]+}} {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: fir.load %{{[0-9]+}}#0 : !fir.ref<i32>
+! CHECK: hlfir.assign {{.*}} to %{{[0-9]+}}#0 : i32, !fir.ref<i32>
+
+subroutine reduction_multiple(a, n)  ! two distinct elements of one array in a single clause
+ integer :: a(4), n
+!$omp parallel do reduction(+: a(2), a(3))
+ do i = 1, n
+ a(2) = a(2) + i
+ a(3) = a(3) + i
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPreduction_multiple
+! CHECK: omp.wsloop {{.*}} reduction(@add_reduction_i32 {{.*}}, @add_reduction_i32 {{.*}} : !fir.ref<i32>, !fir.ref<i32>) {
+! CHECK: hlfir.declare %arg{{[0-9]+}} {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: hlfir.declare %arg{{[0-9]+}} {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: hlfir.assign {{.*}} to %{{[0-9]+}}#0 : i32, !fir.ref<i32>
+! CHECK: hlfir.assign {{.*}} to %{{[0-9]+}}#0 : i32, !fir.ref<i32>
+
+subroutine reduction_arrays(a, b, n)  ! same subscript on two arrays; only a(2) is written
+ integer :: a(4), b(4), n
+!$omp parallel do reduction(+: a(2), b(2))
+ do i = 1, n
+ a(2) = a(2) + b(2) + i
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPreduction_arrays
+! CHECK: omp.wsloop {{.*}} reduction(@add_reduction_i32 {{.*}}, @add_reduction_i32 {{.*}} : !fir.ref<i32>, !fir.ref<i32>) {
+! CHECK: hlfir.declare %arg{{[0-9]+}} {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: hlfir.declare %arg{{[0-9]+}} {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+subroutine reduction_variable(a, n, j)  ! element chosen by the runtime subscript j
+ integer :: a(4), n, j
+!$omp parallel do reduction(+: a(j))
+ do i = 1, n
+ a(j) = a(j) + i
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPreduction_variable
+! CHECK: omp.wsloop {{.*}} reduction(@add_reduction_i32 {{.*}} : !fir.ref<i32>) {
+! CHECK: hlfir.declare %arg{{[0-9]+}} {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: hlfir.assign {{.*}} to %{{[0-9]+}}#0 : i32, !fir.ref<i32>
+
+subroutine reduction_do_simd(a, n)  ! composite construct: element reduced on wsloop and simd
+ integer :: a(4), n
+!$omp parallel do simd reduction(+: a(2))
+ do i = 1, n
+ a(2) = a(2) + i
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPreduction_do_simd
+! CHECK: omp.wsloop reduction(@add_reduction_i32 {{.*}} -> [[WSARG:%arg[0-9]+]] : !fir.ref<i32>) {
+! CHECK: omp.simd {{.*}} reduction(@add_reduction_i32 [[WSARG]] -> [[SIMDARG:%arg[0-9]+]] : !fir.ref<i32>) {
+! CHECK: hlfir.declare [[SIMDARG]] {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: hlfir.assign {{.*}} to %{{[0-9]+}}#0 : i32, !fir.ref<i32>
+
+subroutine task_reduction_element(a)  ! taskgroup task_reduction with a matching task in_reduction
+ integer :: a(4)
+!$omp taskgroup task_reduction(+: a(2))
+!$omp task in_reduction(+: a(2))
+ a(2) = a(2) + 1
+!$omp end task
+!$omp end taskgroup
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtask_reduction_element
+! CHECK: omp.taskgroup task_reduction(@add_reduction_i32 {{.*}} -> [[TGARG:%arg[0-9]+]] : !fir.ref<i32>) {
+! CHECK: [[TGDECL:%[0-9]+]]:2 = hlfir.declare [[TGARG]] {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: omp.task in_reduction(@add_reduction_i32 [[TGDECL]]#0 -> [[TASKARG:%arg[0-9]+]] : !fir.ref<i32>)
+! CHECK: [[TASKDECL:%[0-9]+]]:2 = hlfir.declare [[TASKARG]] {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: fir.load [[TASKDECL]]#0 : !fir.ref<i32>
+! CHECK: hlfir.assign {{.*}} to [[TASKDECL]]#0 : i32, !fir.ref<i32>
+
+subroutine taskloop_in_reduction_element(a, n)  ! in_reduction on taskloop; task_reduction presumably in caller
+ integer :: a(4), n
+!$omp taskloop in_reduction(+: a(2))
+ do i = 1, n
+ a(2) = a(2) + i
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtaskloop_in_reduction_element
+! CHECK: omp.taskloop.context in_reduction(@add_reduction_i32 {{.*}} -> [[TLARG:%arg[0-9]+]] : !fir.ref<i32>)
+! CHECK: hlfir.declare [[TLARG]] {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: hlfir.assign {{.*}} to %{{[0-9]+}}#0 : i32, !fir.ref<i32>
+
+subroutine taskloop_reduction_element(a, n)  ! reduction clause on a taskloop
+ integer :: a(4), n
+!$omp taskloop reduction(+: a(2))
+ do i = 1, n
+ a(2) = a(2) + i
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtaskloop_reduction_element
+! CHECK: omp.taskloop.context {{.*}} reduction(@add_reduction_i32 {{.*}} -> [[TLRARG:%arg[0-9]+]] : !fir.ref<i32>)
+! CHECK: hlfir.declare [[TLRARG]] {uniq_name = "omp.reduction.element"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: hlfir.assign {{.*}} to %{{[0-9]+}}#0 : i32, !fir.ref<i32>
More information about the flang-commits
mailing list