[flang-commits] [flang] [Flang][OpenMP] Correct ArrayElements in Reduction Clause (PR #196094)
via flang-commits
flang-commits at lists.llvm.org
Wed May 6 08:22:09 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Jack Styles (Stylie777)
<details>
<summary>Changes</summary>
Currently, when an ArrayElement is used within a Reduction clause, it will be lowered with the reduction referencing the box containing the array, not just the element.
To address this, adjust Flang lowering to track expressions alongside symbols, ensuring that the Array Element context is not lost and is taken into account when lowering a reduction with an Array Element. This ensures that, when represented in HLFIR, the reduction uses just the element's type rather than the full array.
Currently, this excludes DO CONCURRENT, as it excludes Array Elements, and is limited to Array Elements, but there are options to extend this to Array Sections in the future.
Assisted-by: Codex
---
Patch is 70.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/196094.diff
9 Files Affected:
- (modified) flang/include/flang/Lower/Support/ReductionProcessor.h (+12-7)
- (modified) flang/include/flang/Support/OpenMP-utils.h (+4-1)
- (modified) flang/lib/Lower/Bridge.cpp (+3-1)
- (modified) flang/lib/Lower/ConvertExprToHLFIR.cpp (+27-4)
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+24-11)
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+11-4)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+258-67)
- (modified) flang/lib/Lower/Support/ReductionProcessor.cpp (+105-58)
- (added) flang/test/Lower/OpenMP/reduction-array-element.f90 (+114)
``````````diff
diff --git a/flang/include/flang/Lower/Support/ReductionProcessor.h b/flang/include/flang/Lower/Support/ReductionProcessor.h
index 0b4a692827a79..a1dab8fbc4d5e 100644
--- a/flang/include/flang/Lower/Support/ReductionProcessor.h
+++ b/flang/include/flang/Lower/Support/ReductionProcessor.h
@@ -38,6 +38,11 @@ namespace Fortran {
namespace lower {
namespace omp {
+struct ReductionValueCache {
+ llvm::DenseMap<const semantics::Symbol *, mlir::Value> symbolCache;
+ lower::ExprToValueMap exprCache;
+};
+
class ReductionProcessor {
public:
// ompOrig: mold/original variable
@@ -145,11 +150,11 @@ class ReductionProcessor {
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
/// \param [in,out] reductionVarCache - optional cache mapping reduction
- /// symbols to their SSA values. When provided, array/box reduction
- /// variables that have already been allocated will be reused instead of
- /// creating new allocas. This ensures that nested composite wrappers
- /// (e.g. wsloop and simd in DO SIMD) share the same SSA values, allowing
- /// the genLoopVars() mapper to correctly remap inner wrapper operands.
+ /// objects to their SSA values. Scalar array elements are keyed by
+ /// expression, while whole-symbol reductions are keyed by symbol. This
+ /// ensures that nested composite wrappers (e.g. wsloop and simd in DO SIMD)
+ /// share the same SSA values without conflating distinct element
+ /// expressions of the same base symbol.
template <typename OpType, typename RedOperatorListTy>
static bool processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
@@ -158,8 +163,8 @@ class ReductionProcessor {
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value>
- *reductionVarCache = nullptr);
+ const llvm::SmallVectorImpl<const Object *> &reductionObjs,
+ lower::SymMap &symMap, ReductionValueCache *reductionVarCache = nullptr);
};
template <typename FloatOp, typename IntegerOp>
diff --git a/flang/include/flang/Support/OpenMP-utils.h b/flang/include/flang/Support/OpenMP-utils.h
index 6d9db2b682c50..435215b887de9 100644
--- a/flang/include/flang/Support/OpenMP-utils.h
+++ b/flang/include/flang/Support/OpenMP-utils.h
@@ -9,6 +9,7 @@
#ifndef FORTRAN_SUPPORT_OPENMP_UTILS_H_
#define FORTRAN_SUPPORT_OPENMP_UTILS_H_
+#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Semantics/symbol.h"
#include "mlir/IR/Builders.h"
@@ -22,12 +23,14 @@ namespace Fortran::common::openmp {
struct EntryBlockArgsEntry {
llvm::ArrayRef<const Fortran::semantics::Symbol *> syms;
llvm::ArrayRef<mlir::Value> vars;
+ llvm::ArrayRef<const Fortran::lower::omp::Object *> objs;
bool isValid() const {
// This check allows specifying a smaller number of symbols than values
// because in some case cases a single symbol generates multiple block
// arguments.
- return syms.size() <= vars.size();
+ return syms.size() <= vars.size() &&
+ (objs.empty() || objs.size() == syms.size());
}
};
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index c5709e1cd94d4..42e14d6a4dea3 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -128,6 +128,7 @@ struct IncrementLoopInfo {
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> reduceSymList;
+ llvm::SmallVector<const Fortran::lower::omp::Object *> reduceObjList;
llvm::SmallVector<fir::ReduceOperationEnum> reduceOperatorList;
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
@@ -2413,7 +2414,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::lower::omp::ReductionProcessor rp;
bool result = rp.processReductionArguments<fir::DeclareReductionOp>(
toLocation(), *this, info.reduceOperatorList, reduceVars,
- reduceVarByRef, reductionDeclSymbols, info.reduceSymList);
+ reduceVarByRef, reductionDeclSymbols, info.reduceSymList,
+ info.reduceObjList, getSymbolMap());
if (!result)
TODO(toLocation(), "Lowering unrecognised reduction type");
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index a57fce53c0ca5..12188706575e0 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1592,6 +1592,29 @@ static bool hasDeferredCharacterLength(const Fortran::semantics::Symbol &sym) {
type->characterTypeSpec().length().isDeferred();
}
+static mlir::Value
+findOverriddenExprValue(const Fortran::lower::ExprToValueMap &map,
+ const Fortran::lower::SomeExpr &expr) {
+ if (auto match = map.find(&expr); match != map.end())
+ return match->second;
+
+ if (!Fortran::evaluate::IsArrayElement(expr))
+ return {};
+
+ for (auto [key, value] : map) {
+ if (key == llvm::DenseMapInfo<
+ const Fortran::lower::SomeExpr *>::getEmptyKey() ||
+ key == llvm::DenseMapInfo<
+ const Fortran::lower::SomeExpr *>::getTombstoneKey())
+ continue;
+ if (Fortran::evaluate::IsArrayElement(*key) &&
+ key->AsFortran() == expr.AsFortran())
+ return value;
+ }
+
+ return {};
+}
+
/// Lower Expr to HLFIR.
class HlfirBuilder {
public:
@@ -1605,12 +1628,12 @@ class HlfirBuilder {
if (const Fortran::lower::ExprToValueMap *map =
getConverter().getExprOverrides()) {
if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
- if (auto match = map->find(&expr); match != map->end())
- return hlfir::EntityWithAttributes{match->second};
+ if (mlir::Value value = findOverriddenExprValue(*map, expr))
+ return hlfir::EntityWithAttributes{value};
} else {
Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
- if (auto match = map->find(&someExpr); match != map->end())
- return hlfir::EntityWithAttributes{match->second};
+ if (mlir::Value value = findOverriddenExprValue(*map, someExpr))
+ return hlfir::EntityWithAttributes{value};
}
}
return Fortran::common::visit([&](const auto &x) { return gen(x); },
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 1c39e90a922cf..0d317fd6fa496 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1498,31 +1498,36 @@ bool ClauseProcessor::processIf(
template <typename T>
void collectReductionSyms(
const T &reduction,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
+ llvm::SmallVectorImpl<const Object *> &reductionObjs) {
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
for (const Object &object : objectList) {
const semantics::Symbol *symbol = object.sym();
reductionSyms.push_back(symbol);
+ reductionObjs.push_back(&object);
}
}
bool ClauseProcessor::processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable, ReductionValueCache *reductionVarCache) const {
return findRepeatableClause<omp::clause::InReduction>(
[&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> inReductionVars;
llvm::SmallVector<bool> inReduceVarByRef;
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
- collectReductionSyms(clause, inReductionSyms);
+ llvm::SmallVector<const Object *> inReductionObjs;
+ collectReductionSyms(clause, inReductionSyms, inReductionObjs);
ReductionProcessor rp;
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
currentLocation, converter,
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
inReductionVars, inReduceVarByRef, inReductionDeclSymbols,
- inReductionSyms))
+ inReductionSyms, inReductionObjs, symTable, reductionVarCache))
TODO(currentLocation, "Lowering unrecognised reduction type");
// Copy local lists into the output.
@@ -1532,6 +1537,7 @@ bool ClauseProcessor::processInReduction(
llvm::copy(inReductionDeclSymbols,
std::back_inserter(result.inReductionSyms));
llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
+ llvm::copy(inReductionObjs, std::back_inserter(outReductionObjs));
});
}
@@ -2024,15 +2030,16 @@ bool ClauseProcessor::processNontemporal(
bool ClauseProcessor::processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value> *reductionVarCache)
- const {
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable, ReductionValueCache *reductionVarCache) const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> reductionVars;
llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
- collectReductionSyms(clause, reductionSyms);
+ llvm::SmallVector<const Object *> reductionObjs;
+ collectReductionSyms(clause, reductionSyms, reductionObjs);
auto mod = std::get<std::optional<ReductionModifier>>(clause.t);
if (mod.has_value()) {
@@ -2049,7 +2056,7 @@ bool ClauseProcessor::processReduction(
currentLocation, converter,
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
reductionVars, reduceVarByRef, reductionDeclSymbols,
- reductionSyms, reductionVarCache))
+ reductionSyms, reductionObjs, symTable, reductionVarCache))
TODO(currentLocation, "Lowering unrecognised reduction type");
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
@@ -2057,26 +2064,31 @@ bool ClauseProcessor::processReduction(
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionSyms));
llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
+ llvm::copy(reductionObjs, std::back_inserter(outReductionObjs));
});
}
bool ClauseProcessor::processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable, ReductionValueCache *reductionVarCache) const {
return findRepeatableClause<omp::clause::TaskReduction>(
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> taskReductionVars;
llvm::SmallVector<bool> taskReduceVarByRef;
llvm::SmallVector<mlir::Attribute> taskReductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
- collectReductionSyms(clause, taskReductionSyms);
+ llvm::SmallVector<const Object *> taskReductionObjs;
+ collectReductionSyms(clause, taskReductionSyms, taskReductionObjs);
ReductionProcessor rp;
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
currentLocation, converter,
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
taskReductionVars, taskReduceVarByRef, taskReductionDeclSymbols,
- taskReductionSyms))
+ taskReductionSyms, taskReductionObjs, symTable,
+ reductionVarCache))
TODO(currentLocation, "Lowering unrecognised reduction type");
// Copy local lists into the output.
llvm::copy(taskReductionVars,
@@ -2086,6 +2098,7 @@ bool ClauseProcessor::processTaskReduction(
llvm::copy(taskReductionDeclSymbols,
std::back_inserter(result.taskReductionSyms));
llvm::copy(taskReductionSyms, std::back_inserter(outReductionSyms));
+ llvm::copy(taskReductionObjs, std::back_inserter(outReductionObjs));
});
}
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index acf1068efb987..b34b531355e1d 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -142,7 +142,10 @@ class ClauseProcessor {
mlir::omp::IfClauseOps &result) const;
bool processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable,
+ ReductionValueCache *reductionVarCache = nullptr) const;
bool processIsDevicePtr(
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
@@ -167,11 +170,15 @@ class ClauseProcessor {
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
- llvm::DenseMap<const semantics::Symbol *, mlir::Value>
- *reductionVarCache = nullptr) const;
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable,
+ ReductionValueCache *reductionVarCache = nullptr) const;
bool processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+ llvm::SmallVectorImpl<const Object *> &outReductionObjs,
+ lower::SymMap &symTable,
+ ReductionValueCache *reductionVarCache = nullptr) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 88d28cf94b045..0484464460225 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -75,6 +75,20 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
mlir::Location loc);
namespace {
+static bool isReductionObjectExpression(const Object *object) {
+ if (!object || !object->ref())
+ return false;
+ const SomeExpr &expr = *object->ref();
+ return evaluate::IsArrayElement(expr);
+}
+
+static std::optional<const SomeExpr *>
+getReductionObjectExpr(const Object *object) {
+ if (!isReductionObjectExpression(object))
+ return std::nullopt;
+ return &object->ref().value();
+}
+
/// Structure holding information that is needed to pass host-evaluated
/// information to later lowering stages.
class HostEvalInfo {
@@ -319,22 +333,52 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
bindSingleMapLike(*sym, arg);
};
- auto bindPrivateLike = [&converter, &firOpBuilder](
+ llvm::SmallPtrSet<const semantics::Symbol *, 8> objectReductionSyms;
+ auto collectObjectReductionSyms =
+ [&objectReductionSyms](llvm::ArrayRef<const Object *> objs) {
+ for (const Object *obj : objs)
+ if (isReductionObjectExpression(obj))
+ objectReductionSyms.insert(&obj->sym()->GetUltimate());
+ };
+ collectObjectReductionSyms(args.inReduction.objs);
+ collectObjectReductionSyms(args.reduction.objs);
+ collectObjectReductionSyms(args.taskReduction.objs);
+
+ auto bindPrivateLike = [&converter, &firOpBuilder, &objectReductionSyms](
llvm::ArrayRef<const semantics::Symbol *> syms,
llvm::ArrayRef<mlir::Value> vars,
- llvm::ArrayRef<mlir::BlockArgument> args) {
+ llvm::ArrayRef<mlir::BlockArgument> args,
+ llvm::ArrayRef<const Object *> objs,
+ bool skipObjectReductionSyms = false) {
+ assert((objs.empty() || objs.size() == syms.size()) &&
+ "invalid object list for private-like clause");
llvm::SmallVector<const semantics::Symbol *> processedSyms;
- for (auto *sym : syms) {
+ llvm::SmallVector<const Object *> processedObjs;
+ for (auto [idx, sym] : llvm::enumerate(syms)) {
+ const Object *obj = objs.empty() ? nullptr : objs[idx];
if (const auto *commonDet =
sym->detailsIf<semantics::CommonBlockDetails>()) {
- llvm::transform(commonDet->objects(), std::back_inserter(processedSyms),
- [&](const auto &mem) { return &*mem; });
+ for (auto &mem : commonDet->objects()) {
+ processedSyms.push_back(&*mem);
+ processedObjs.push_back(obj);
+ }
} else {
processedSyms.push_back(sym);
+ processedObjs.push_back(obj);
}
}
- for (auto [sym, var, arg] : llvm::zip_equal(processedSyms, vars, args))
+ assert(processedSyms.size() == processedObjs.size());
+ for (auto [sym, var, arg, obj] :
+ llvm::zip_equal(processedSyms, vars, args, processedObjs)) {
+ bool skipBind =
+ isReductionObjectExpression(obj) ||
+ (obj && sym->Rank() > 0 && !fir::unwrapUntilSeqType(arg.getType())) ||
+ (skipObjectReductionSyms &&
+ objectReductionSyms.contains(&sym->GetUltimate()));
+ if (skipBind)
+ continue;
+
converter.bindSymbol(
*sym,
hlfir::translateToExtendedValue(
@@ -342,6 +386,7 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
/*contiguousHint=*/
evaluate::IsSimplyContiguous(*sym, converter.getFoldingContext()))
.first);
+ }
};
// Process in clause name alphabetical order to match block arguments order.
@@ -349,13 +394,14 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
// corresponding region, except for very specific cases handled separately.
bindMapLike(args.hasDeviceAddr.syms, op.getHasDeviceAddrBlockArgs());
bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
- op.getInReductionBlockArgs());
+ op.getInReductionBlockArgs(), args.inReduction.objs);
bindMapLike(args.map.syms, op.getMapBlockArgs());
- bindPrivateLike(args.priv.syms, args.priv.vars, op.getPrivateBlockArgs());
+ bindPrivateLike(args.priv.syms, args.priv.vars, op.getPrivateBlockArgs(),
+ args.priv.objs, /*skipObjectReductionSyms=*/true);
bindPrivateLike(args.reduction.syms, args.reduction.vars,
- op.getReductionBlockArgs());
+ op.getReductionBlockArgs(), args.reduction.objs);
bindPrivateLike(args.taskReduction.syms, args.taskReduction.vars,
- op.getTaskReductionBlockArgs());
+ op.getTaskReductionBlockArgs(), args.taskReduction.objs);
bindMapLike(args.useDeviceAd...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/196094
More information about the flang-commits
mailing list