[flang-commits] [flang] e80604a - [flang][OpenMP] Support user-defined declare reduction with derived types (#184897)
via flang-commits
flang-commits at lists.llvm.org
Thu Mar 26 08:48:52 PDT 2026
Author: Matt
Date: 2026-03-26T15:48:47Z
New Revision: e80604a6418404934a47bb3bfc14b4a21c1de626
URL: https://github.com/llvm/llvm-project/commit/e80604a6418404934a47bb3bfc14b4a21c1de626
DIFF: https://github.com/llvm/llvm-project/commit/e80604a6418404934a47bb3bfc14b4a21c1de626.diff
LOG: [flang][OpenMP] Support user-defined declare reduction with derived types (#184897)
Fix lowering of `!$omp declare reduction` for intrinsic operators
applied to user-defined derived types (e.g., `+` on `type(t)`).
Previously, this hit a TODO in `ReductionProcessor::getReductionInitValue`
because the code tried to compute an init value for a non-predefined
type, when it should instead use the initializer region from the
`DeclareReductionOp`.

This fixes issue #176278: [Flang][OpenMP] Compilation error when
type-list in declare reduction directive is derived type name.

The root cause was a naming mismatch: `genOMP` for
`OpenMPDeclareReductionConstruct` used a raw operator string (e.g.,
"Add") as the reduction name, while `processReductionArguments` at the
use site computed a canonical name via `getReductionName` (e.g.,
"add_reduction_byref_rec__QFTt"). The `lookupSymbol` in
`createDeclareReductionHelper` therefore never found the already-created
op, so it fell through to `createDeclareReduction`, which called
`getReductionInitValue` with the derived type and hit the TODO.

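The root cause is a keyed lookup that misses because the declaration site and the use site spell the key differently. A minimal C++ model of that failure mode follows; it is not Flang code. `DeclareReduction`, `ModuleSymbols`, and `lookupOrCreate` are hypothetical stand-ins for the `omp.declare_reduction` op, the module symbol table, and the `lookupSymbol`-then-`createDeclareReduction` fallback, and the thrown message reuses the TODO text from the old negative test.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for an omp.declare_reduction op.
struct DeclareReduction {
  std::string name;
  bool hasUserInitializer;
};

// Hypothetical stand-in for the module symbol table (symbol name -> op).
using ModuleSymbols = std::map<std::string, DeclareReduction>;

// Models the lookup-then-create path: reuse a declaration found under
// `name`; otherwise build one from a predefined init value, which this
// model (like the old code) cannot do for derived types.
const DeclareReduction &lookupOrCreate(ModuleSymbols &symbols,
                                       const std::string &name,
                                       bool isDerivedType) {
  if (auto it = symbols.find(name); it != symbols.end())
    return it->second;
  if (isDerivedType)
    throw std::runtime_error(
        "not yet implemented: Reduction of some types is not supported");
  return symbols.emplace(name, DeclareReduction{name, false}).first->second;
}
```

Registering the user's declaration under "Add" and then looking it up as "add_reduction_byref_rec__QFTt" reaches the throwing branch; registering it under the canonical name makes the same lookup succeed.
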
The fix has three parts:

1. Consistent names: In `genOMP` for `OpenMPDeclareReductionConstruct`,
   compute the reduction name with the same `getReductionName` scheme
   that `processReductionArguments` uses, so both sites produce
   identical symbol names. For intrinsic operators, this maps through
   `ReductionIdentifier` to get the canonical name. For user-defined
   named reductions, the raw symbol name is used directly, matching the
   existing custom-reduction lookup path.
2. Reuse reduction: In `processReductionArguments`, when an intrinsic
   operator reduction is requested, check whether a user-defined declare
   reduction already exists under that canonical name before attempting
   to create a new one. If one is found, reuse it. This avoids calling
   `createDeclareReduction` (and thus `getReductionInitValue`) for types
   that have user-provided initializers.
3. Reference semantics: Change `doReductionByRef` to return true for
   derived types. Previously it returned false for both trivial and
   derived types, treating derived types as by-value. This is incorrect
   for user-defined combiners that operate on components via side
   effects (e.g., `omp_out%x = omp_out%x + omp_in%x`): the combiner
   mutates `omp_out` in place and doesn't produce a whole-struct value,
   so `convertExprToValue` returns the component type (`i32`) rather
   than the struct type, causing a type mismatch in the `omp.yield`.
   By-ref is the correct model: the combiner stores into the lhs
   reference and yields it.

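The three parts interact: because `isByRef` feeds into the reduction name, the declaration site must decide by-ref-ness exactly as the use site does, which is why `genOMP` now computes `isByRef` before the name. The C++ sketch below models that contract under stated assumptions: `TypeCategory` and `Module` are simplified, hypothetical stand-ins, and the `<op>_reduction[_byref]_<type>` layout is inferred from the example name `add_reduction_byref_rec__QFTt`, not taken from `getReductionName` itself.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical, simplified type categories (not FIR's type system).
enum class TypeCategory { Trivial, Derived, Box, Array };

// Part 3: only trivial scalars are reduced by value; derived types,
// boxes and arrays are reduced by reference.
bool doReductionByRef(TypeCategory cat) {
  return cat != TypeCategory::Trivial;
}

// Part 1: one naming function shared by both sites. The layout is an
// assumption modelled on "add_reduction_byref_rec__QFTt".
std::string getReductionName(const std::string &op,
                             const std::string &typeSuffix, bool isByRef) {
  return op + "_reduction" + (isByRef ? "_byref" : "") + "_" + typeSuffix;
}

// Hypothetical module holding the names of declared reductions.
struct Module {
  std::set<std::string> declared;

  // Declaration site: register under the canonical name.
  void declare(const std::string &op, const std::string &typeSuffix,
               TypeCategory cat) {
    declared.insert(getReductionName(op, typeSuffix, doReductionByRef(cat)));
  }

  // Part 2, use site: look for a user declaration before creating one.
  bool hasUserDeclaration(const std::string &op,
                          const std::string &typeSuffix,
                          TypeCategory cat) const {
    return declared.count(
               getReductionName(op, typeSuffix, doReductionByRef(cat))) != 0;
  }
};
```

If the two sites disagreed on `doReductionByRef`, the `_byref` fragment would differ and the lookup would miss again, even with a shared naming function.
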
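The combiner callback in `processReductionCombiner` is also updated to
handle the by-ref derived-type case. When the combiner result type
doesn't match the element type (as happens with component-level
assignments such as `omp_out%x = omp_out%x + omp_in%x`), the lowered
value is assigned to the component designated by the typed assignment's
left-hand side, the whole-variable store is skipped, and only the lhs
reference is yielded. Combiners that assign a whole derived-type value
(e.g., `omp_out = mycombine(omp_out, omp_in)`) are lowered through
`hlfir.assign` or the user-defined assignment instead.
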
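In by-ref form, the combiner's contract is "update `omp_out` in place, then yield the lhs reference", whether or not the combiner expression produced a whole value. The C++ analogy below (not the MLIR lowering) illustrates that decision; `T`, `Combiner`, and `combineByRef` are hypothetical, and `std::nullopt` stands for a combiner whose result is not a whole `T`, such as a component-level assignment.

```cpp
#include <cassert>
#include <functional>
#include <optional>

// Hypothetical derived type with one component, like `type t; integer :: x`.
struct T {
  int x;
};

// A combiner either returns a whole new value for omp_out, or updates a
// component of omp_out itself and returns std::nullopt.
using Combiner = std::function<std::optional<T>(T &ompOut, const T &ompIn)>;

// By-ref combine: store a whole value only if one was produced, then
// return (yield) the lhs reference in both cases.
T *combineByRef(T *lhs, const T *rhs, const Combiner &combiner) {
  if (std::optional<T> result = combiner(*lhs, *rhs))
    *lhs = *result; // whole-variable store
  return lhs;       // analogous to omp.yield(%lhs)
}
```
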
Test updates:
- Update `declare-reduction-intrinsic-op.f90` from a negative test
  (checking for the TODO error) to a positive test checking the
  generated MLIR.
- Update the `omp-declare-reduction-derivedtype.f90` CHECK lines to match
  the reference-semantics fix: the `declare_reduction` now has type
  `!fir.ref<...>` with a `byref_element_type` attribute, an alloc region,
  a two-argument init region, and a combiner that stores into the lhs
  and yields the reference. The function-body checks for `initme` and
  `mycombine` are unchanged in substance, but use literal type names
  instead of a regex capture to avoid greedy-matching issues with nested
  angle brackets.

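The greedy-matching problem can be reproduced outside FileCheck. The sketch below uses `std::regex` (ECMAScript grammar) as a stand-in for FileCheck's own regex engine; in both, `.*` is greedy. `firstCapture` is a hypothetical helper, and the block-argument line used with it is a simplified, made-up example of nested FIR types.

```cpp
#include <cassert>
#include <regex>
#include <string>

// Returns the first capture group of `pattern` in `line`, or "" if no match.
std::string firstCapture(const std::string &line, const std::string &pattern) {
  std::smatch m;
  if (std::regex_search(line, m, std::regex(pattern)))
    return m[1].str();
  return "";
}
```

On a line such as `^bb0(%a: !fir.ref<!fir.type<t{x:i32}>>, %b: !fir.ref<!fir.type<t{x:i32}>>):`, a greedy capture inside `!fir.ref<...>` swallows everything up to the last `>` on the line, while a lazy `.*?` (an ECMAScript feature FileCheck patterns may not offer) stops at the first `>` and cuts the type in half. Spelling the type out literally avoids both.
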
Remaining work: declare reduction without an initializer clause is not
yet supported. I plan to address that in a follow-up.

Assisted-by: Claude Opus 4.6.
Note: I relied on an LLM (Claude Opus 4.6) to help navigate the Flang
APIs and to assist with the corresponding boilerplate code and test
updates. In particular, to get the consistent naming described above,
in `ReductionProcessor::getReductionName` I had to drop
`parser::DefinedOperator::EnumToString` and instead introduce
`getRedIdFromParserIntrOp`, which does the conversion manually. Just to
make sure I haven't missed anything: is there no existing conversion
function? AFAICT there is none, but I might have missed it. In any
case, feedback is welcome!

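A manual conversion of the kind described usually takes the shape of an exhaustive `switch`. The sketch below is illustrative only: both enums are simplified, hypothetical stand-ins (their enumerator lists are not copied from Flang's headers), and `toReductionIdentifier` is not the function added by this commit.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for the parser's intrinsic-operator enum.
enum class IntrinsicOperator {
  Power, Multiply, Divide, Add, Subtract, Concat,
  LT, LE, EQ, NE, GE, GT, NOT, AND, OR, EQV, NEQV
};

// Hypothetical stand-in for the reduction-identifier enum (operator
// identifiers only; intrinsic-procedure identifiers are omitted).
enum class ReductionIdentifier { ADD, SUBTRACT, MULTIPLY, AND, OR, EQV, NEQV };

// Maps an intrinsic operator to a reduction identifier, or std::nullopt if
// the operator is not a valid reduction operator. Listing every enumerator
// without a `default:` lets -Wswitch flag enumerators added later.
std::optional<ReductionIdentifier> toReductionIdentifier(IntrinsicOperator op) {
  switch (op) {
  case IntrinsicOperator::Add:
    return ReductionIdentifier::ADD;
  case IntrinsicOperator::Subtract:
    return ReductionIdentifier::SUBTRACT;
  case IntrinsicOperator::Multiply:
    return ReductionIdentifier::MULTIPLY;
  case IntrinsicOperator::AND:
    return ReductionIdentifier::AND;
  case IntrinsicOperator::OR:
    return ReductionIdentifier::OR;
  case IntrinsicOperator::EQV:
    return ReductionIdentifier::EQV;
  case IntrinsicOperator::NEQV:
    return ReductionIdentifier::NEQV;
  case IntrinsicOperator::Power:
  case IntrinsicOperator::Divide:
  case IntrinsicOperator::Concat:
  case IntrinsicOperator::LT:
  case IntrinsicOperator::LE:
  case IntrinsicOperator::EQ:
  case IntrinsicOperator::NE:
  case IntrinsicOperator::GE:
  case IntrinsicOperator::GT:
  case IntrinsicOperator::NOT:
    return std::nullopt;
  }
  return std::nullopt; // only reachable with an out-of-range value
}
```

Returning `std::optional` keeps the "not a reduction operator" case explicit instead of mapping it to a sentinel enumerator.
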
---------
Co-authored-by: Matt P. Dziubinski <matt-p.dziubinski at hpe.com>
Added:
flang/test/Lower/OpenMP/declare-reduction-finalizer.f90
Modified:
flang/include/flang/Lower/OpenMP/Clauses.h
flang/include/flang/Lower/Support/ReductionProcessor.h
flang/lib/Lower/OpenMP/ClauseProcessor.cpp
flang/lib/Lower/OpenMP/OpenMP.cpp
flang/lib/Lower/Support/PrivateReductionUtils.cpp
flang/lib/Lower/Support/ReductionProcessor.cpp
flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Lower/OpenMP/Clauses.h b/flang/include/flang/Lower/OpenMP/Clauses.h
index a325e74327240..f334374280c73 100644
--- a/flang/include/flang/Lower/OpenMP/Clauses.h
+++ b/flang/include/flang/Lower/OpenMP/Clauses.h
@@ -329,6 +329,17 @@ using UsesAllocators = tomp::clause::UsesAllocatorsT<TypeTy, IdTy, ExprTy>;
using Weak = tomp::clause::WeakT<TypeTy, IdTy, ExprTy>;
using When = tomp::clause::WhenT<TypeTy, IdTy, ExprTy>;
using Write = tomp::clause::WriteT<TypeTy, IdTy, ExprTy>;
+
+DefinedOperator makeDefinedOperator(const parser::DefinedOperator &inp,
+ semantics::SemanticsContext &semaCtx);
+
+ProcedureDesignator
+makeProcedureDesignator(const parser::ProcedureDesignator &inp,
+ semantics::SemanticsContext &semaCtx);
+
+ReductionOperator
+makeReductionOperator(const parser::OmpReductionIdentifier &inp,
+ semantics::SemanticsContext &semaCtx);
} // namespace clause
using tomp::type::operator==;
diff --git a/flang/include/flang/Lower/Support/ReductionProcessor.h b/flang/include/flang/Lower/Support/ReductionProcessor.h
index bd0447360f089..bbc4879bbe352 100644
--- a/flang/include/flang/Lower/Support/ReductionProcessor.h
+++ b/flang/include/flang/Lower/Support/ReductionProcessor.h
@@ -40,9 +40,11 @@ namespace omp {
class ReductionProcessor {
public:
- using GenInitValueCBTy =
- std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Type type, mlir::Value ompOrig)>;
+ // ompOrig: mold/original variable
+ // ompPriv: private allocation (may be null for by-value reductions)
+ using GenInitValueCBTy = std::function<mlir::Value(
+ fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
+ mlir::Value ompOrig, mlir::Value ompPriv)>;
using GenCombinerCBTy = std::function<void(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
mlir::Value op1, mlir::Value op2, bool isByRef)>;
@@ -126,7 +128,8 @@ class ReductionProcessor {
static DeclareRedType createDeclareReductionHelper(
AbstractConverter &converter, llvm::StringRef reductionOpName,
mlir::Type type, mlir::Location loc, bool isByRef,
- GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
+ GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB,
+ const semantics::Symbol *sym = nullptr);
/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The declaration has a constant initializer with the neutral
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 75df4163c8a55..eb4be1f6dc1bd 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -507,7 +507,8 @@ bool ClauseProcessor::processInitializer(
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Type type, mlir::Value ompOrig) {
+ mlir::Type type, mlir::Value moldArg,
+ mlir::Value privArg) {
lower::SymMapScope scope(symMap);
mlir::Value ompPrivVar;
const StylizedInstance &inst = clause->v.front();
@@ -515,9 +516,10 @@ bool ClauseProcessor::processInitializer(
for (const Object &object :
std::get<StylizedInstance::Variables>(inst.t)) {
mlir::Value addr;
- mlir::Type ompOrigType = ompOrig.getType();
+ std::string name = object.sym()->name().ToString();
+ mlir::Type moldArgType = moldArg.getType();
// Check for unsupported dynamic-length character reductions
- mlir::Type unwrappedType = fir::unwrapRefType(ompOrigType);
+ mlir::Type unwrappedType = fir::unwrapRefType(moldArgType);
if (mlir::isa<fir::BoxCharType>(unwrappedType)) {
TODO(loc, "OpenMP reduction allocation for dynamic length character");
}
@@ -527,18 +529,20 @@ bool ClauseProcessor::processInitializer(
"OpenMP reduction allocation for dynamic length character");
}
}
- // If ompOrig is already a reference, we can use it directly
- if (fir::isa_ref_type(ompOrigType)) {
- addr = ompOrig;
+ // For by-ref reductions, omp_priv maps to privArg (the private
+ // allocation) and omp_orig maps to moldArg (the original).
+ if (name == "omp_priv" && privArg) {
+ addr = privArg;
+ } else if (fir::isa_ref_type(moldArgType)) {
+ addr = moldArg;
} else {
- addr = builder.createTemporary(loc, ompOrigType);
- fir::StoreOp::create(builder, loc, ompOrig, addr);
+ addr = builder.createTemporary(loc, moldArgType);
+ fir::StoreOp::create(builder, loc, moldArg, addr);
}
fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(
builder.getContext(), *object.sym(), extraFlags);
- std::string name = object.sym()->name().ToString();
// Get length parameters for types that need them (e.g., characters).
// Note: DeclareOp requires exactly one type parameter for non-boxed
// characters, unlike EmboxOp which doesn't allow them for constant-len.
@@ -570,9 +574,6 @@ bool ClauseProcessor::processInitializer(
[&](const auto &expr) -> mlir::Value {
mlir::Value exprResult = fir::getBase(convertExprToValue(
loc, converter, initExpr, symMap, stmtCtx));
- // Conversion can either give a value or a refrence to a value,
- // we need to return the reduction type, so an optional load may
- // be generated.
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
exprResult.getType()))
if (ompPrivVar.getType() == refType)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f86f15921b05e..3da35eafa8190 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -18,6 +18,7 @@
#include "Decomposer.h"
#include "Utils.h"
#include "flang/Common/idioms.h"
+#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/tools.h"
#include "flang/Evaluate/type.h"
#include "flang/Lower/Bridge.h"
@@ -3836,14 +3837,27 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx, const clause::Combiner &combiner) {
+ semantics::SemanticsContext &semaCtx, const clause::Combiner &combiner,
+ const parser::OmpStylizedInstance &parserInst) {
+ // Extract the typed assignment from the parser-level instance, if
+ // the combiner is an assignment statement (as opposed to a call).
+ const evaluate::Assignment *assign = nullptr;
+ const auto &instance =
+ std::get<parser::OmpStylizedInstance::Instance>(parserInst.t);
+ if (const auto *assignStmt =
+ std::get_if<parser::AssignmentStmt>(&instance.u)) {
+ if (auto *wrapper = assignStmt->typedAssignment.get())
+ if (wrapper->v)
+ assign = &*wrapper->v;
+ }
ReductionProcessor::GenCombinerCBTy genCombinerCB;
const StylizedInstance &inst = combiner.v.front();
semantics::SomeExpr evalExpr = std::get<StylizedInstance::Instance>(inst.t);
- genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Type type, mlir::Value lhs,
- mlir::Value rhs, bool isByRef) {
+ genCombinerCB = [&, evalExpr, assign](fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Type type,
+ mlir::Value lhs, mlir::Value rhs,
+ bool isByRef) {
lower::SymMapScope scope(symTable);
mlir::Value ompOutVar;
for (const Object &object : std::get<StylizedInstance::Variables>(inst.t)) {
@@ -3878,6 +3892,44 @@ static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
symTable.addVariableDefinition(*object.sym(), declareOp);
}
+ // For derived types with a typed assignment available, use
+ // hlfir::AssignOp or user-defined assignment directly instead of
+ // trying to convert the expression to a value (which doesn't work
+ // for record types). Only take this path when the assignment RHS
+ // itself is a derived type -- i.e. the combiner assigns to the whole
+ // derived-type variable (e.g. omp_out = mycombine(omp_out, omp_in)).
+ // When the combiner assigns to a component (e.g. omp_out%x = ...),
+ // the RHS is a scalar intrinsic type and the existing convertExprToValue
+ // path handles it correctly.
+ bool rhsIsDerived =
+ assign && assign->rhs.GetType() &&
+ assign->rhs.GetType()->category() == common::TypeCategory::Derived;
+ if (rhsIsDerived && isByRef &&
+ mlir::isa<fir::RecordType>(fir::unwrapRefType(lhs.getType()))) {
+ lower::StatementContext stmtCtx;
+ hlfir::Entity lhsEntity{ompOutVar};
+ hlfir::Entity rhsEntity = lower::convertExprToHLFIR(
+ loc, converter, assign->rhs, symTable, stmtCtx);
+ common::visit(
+ common::visitors{
+ [&](const evaluate::Assignment::Intrinsic &) {
+ hlfir::AssignOp::create(builder, loc, rhsEntity, lhsEntity);
+ },
+ [&](const evaluate::ProcedureRef &procRef) {
+ lower::convertUserDefinedAssignmentToHLFIR(
+ loc, converter, procRef, lhsEntity, rhsEntity, symTable);
+ },
+ [&](const auto &) {
+ llvm_unreachable(
+ "Unexpected assignment type in reduction combiner");
+ },
+ },
+ assign->u);
+ stmtCtx.finalizeAndPop();
+ mlir::omp::YieldOp::create(builder, loc, lhs);
+ return;
+ }
+
lower::StatementContext stmtCtx;
mlir::Value result = common::visit(
common::visitors{
@@ -3885,6 +3937,10 @@ static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
convertCallToHLFIR(loc, converter, procRef, std::nullopt,
symTable, stmtCtx);
auto outVal = fir::LoadOp::create(builder, loc, ompOutVar);
+ if (isByRef) {
+ fir::StoreOp::create(builder, loc, outVal, lhs);
+ return mlir::Value{};
+ }
return outVal;
},
[&](const auto &expr) -> mlir::Value {
@@ -3899,12 +3955,35 @@ static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
if (expectedType == refType.getElementType())
exprResult = fir::LoadOp::create(builder, loc, exprResult);
}
+ // For component-level derived-type combiners (e.g.
+ // omp_out%x = omp_out%x + omp_in%x), the assignment was
+ // not performed during expression lowering since
+ // convertExprToValue only evaluates the RHS value.
+ // The result type won't match the reduction variable type.
+ // Use the typed assignment LHS to store to the correct
+ // component, then skip the whole-variable store.
+ if (isByRef &&
+ exprResult.getType() != fir::unwrapRefType(lhs.getType())) {
+ if (assign) {
+ lower::StatementContext assignCtx;
+ hlfir::Entity lhsEntity = lower::convertExprToHLFIR(
+ loc, converter, assign->lhs, symTable, assignCtx);
+ hlfir::AssignOp::create(builder, loc, exprResult, lhsEntity);
+ assignCtx.finalizeAndPop();
+ } else {
+ fir::StoreOp::create(builder, loc, exprResult, ompOutVar);
+ }
+ return mlir::Value{};
+ }
+ if (isByRef) {
+ fir::StoreOp::create(builder, loc, exprResult, lhs);
+ return mlir::Value{};
+ }
return exprResult;
}},
evalExpr.u);
stmtCtx.finalizeAndPop();
if (isByRef) {
- fir::StoreOp::create(builder, loc, result, lhs);
mlir::omp::YieldOp::create(builder, loc, lhs);
} else {
mlir::omp::YieldOp::create(builder, loc, result);
@@ -3997,41 +4076,83 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
const auto &identifier =
std::get<parser::OmpReductionIdentifier>(specifier.t);
- std::string reductionNameStr = Fortran::common::visit(
- common::visitors{
- [](const parser::ProcedureDesignator &pd) -> std::string {
- return std::get<parser::Name>(pd.u).ToString();
- },
- [](const parser::DefinedOperator &defOp) -> std::string {
- return Fortran::common::visit(
- common::visitors{
- [](const parser::DefinedOpName &opName) -> std::string {
- return opName.v.ToString();
- },
- [](parser::DefinedOperator::IntrinsicOperator intrOp)
- -> std::string {
- return std::string(
- parser::DefinedOperator::EnumToString(intrOp));
- },
- },
- defOp.u);
- },
- },
- identifier.u);
+ // Convert the parser-level reduction identifier to the clause-level
+ // representation, then use ReductionProcessor to derive the canonical name.
+ clause::ReductionOperator redOp =
+ clause::makeReductionOperator(identifier, semaCtx);
+
+ // Get the parser-level combiner expression so we can pass each
+ // parser::OmpStylizedInstance to processReductionCombiner.
+ // The combiner expression's instances correspond 1:1 to typeNameList entries.
+ const auto *combinerExpr = parser::omp::GetCombinerExpr(specifier);
+ assert(combinerExpr && "Expecting combiner expression");
+ auto parserInstIt = combinerExpr->v.begin();
for (const auto &typeSpec : typeNameList.v) {
(void)typeSpec; // Currently unused
+
+ assert(parserInstIt != combinerExpr->v.end() &&
+ "Mismatched combiner instance count");
+ const parser::OmpStylizedInstance &parserInst = *parserInstIt++;
+
mlir::Type reductionType = getReductionType(converter, specifier);
+ bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+ // Compute the canonical reduction name the same way
+ // processReductionArguments does.
+ std::string reductionNameStr = Fortran::common::visit(
+ common::visitors{
+ [&](const clause::DefinedOperator &defOp) -> std::string {
+ return Fortran::common::visit(
+ common::visitors{
+ [&](const clause::DefinedOperator::IntrinsicOperator
+ &intrOp) -> std::string {
+ ReductionProcessor::ReductionIdentifier redId =
+ ReductionProcessor::getReductionType(intrOp);
+ return ReductionProcessor::getReductionName(
+ redId, converter.getFirOpBuilder().getKindMap(),
+ reductionType, isByRef);
+ },
+ [&](const clause::DefinedOperator::DefinedOpName &opName)
+ -> std::string {
+ return opName.v.sym()->name().ToString();
+ },
+ },
+ defOp.u);
+ },
+ [&](const clause::ProcedureDesignator &pd) -> std::string {
+ return pd.v.sym()->name().ToString();
+ },
+ },
+ redOp.u);
+
ReductionProcessor::GenCombinerCBTy genCombinerCB =
- processReductionCombiner(converter, symTable, semaCtx, combiner);
+ processReductionCombiner(converter, symTable, semaCtx, combiner,
+ parserInst);
ReductionProcessor::GenInitValueCBTy genInitValueCB;
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processInitializer(symTable, genInitValueCB);
- bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+ mlir::Type redType =
+ isByRef
+ ? static_cast<mlir::Type>(fir::ReferenceType::get(reductionType))
+ : reductionType;
+
+ // Get the omp_out symbol from the combiner for finalization checks
+ // in populateByRefInitAndCleanupRegions.
+ const semantics::Symbol *reductionSym = nullptr;
+ const auto &declList =
+ std::get<std::list<parser::OmpStylizedDeclaration>>(parserInst.t);
+ for (const auto &decl : declList) {
+ const auto &name = std::get<parser::ObjectName>(decl.var.t);
+ if (name.ToString() == "omp_out") {
+ reductionSym = name.symbol;
+ break;
+ }
+ }
+
ReductionProcessor::createDeclareReductionHelper<
mlir::omp::DeclareReductionOp>(
- converter, reductionNameStr, reductionType,
- converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
+ converter, reductionNameStr, redType, converter.getCurrentLocation(),
+ isByRef, genCombinerCB, genInitValueCB, reductionSym);
}
}
diff --git a/flang/lib/Lower/Support/PrivateReductionUtils.cpp b/flang/lib/Lower/Support/PrivateReductionUtils.cpp
index aae433c023d01..aaf2069ec34bd 100644
--- a/flang/lib/Lower/Support/PrivateReductionUtils.cpp
+++ b/flang/lib/Lower/Support/PrivateReductionUtils.cpp
@@ -155,6 +155,20 @@ static void createCleanupRegion(Fortran::lower::AbstractConverter &converter,
return;
}
+ // Handle unboxed derived types that need finalization (e.g. types with
+ // FINAL subroutines). Embox the reference and call the runtime destroy.
+ if (fir::isa_derived(valTy) && mlir::isa<fir::ReferenceType>(argType)) {
+ mlir::Type boxTy = fir::BoxType::get(valTy);
+ mlir::Value box =
+ fir::EmboxOp::create(builder, loc, boxTy, block->getArgument(0));
+ fir::runtime::genDerivedTypeDestroy(builder, loc, box);
+ if (isDoConcurrent)
+ fir::YieldOp::create(builder, loc);
+ else
+ mlir::omp::YieldOp::create(builder, loc);
+ return;
+ }
+
typeError();
}
@@ -636,6 +650,18 @@ void PopulateInitAndCleanupRegionsHelper::initAndCleanupBoxchar(
void PopulateInitAndCleanupRegionsHelper::initAndCleanupUnboxedDerivedType(
bool needsInitialization) {
builder.setInsertionPointToStart(initBlock);
+ // For reductions with a user-provided init value, store it into the
+ // private variable. Insert after the init value's defining op to
+ // maintain SSA dominance (the init value was generated by the
+ // callback before populateByRefInitAndCleanupRegions was called).
+ if (scalarInitValue && isReduction(kind)) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ if (auto *defOp = scalarInitValue.getDefiningOp())
+ builder.setInsertionPointAfter(defOp);
+ else
+ builder.setInsertionPointToEnd(initBlock);
+ fir::StoreOp::create(builder, loc, scalarInitValue, allocatedPrivVarArg);
+ }
mlir::Type boxedTy = fir::BoxType::get(valType);
mlir::Value newBox =
fir::EmboxOp::create(builder, loc, boxedTy, allocatedPrivVarArg);
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index e0cba4c512258..36bd1f1ef2397 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -502,7 +502,7 @@ template <typename OpType>
static void createReductionAllocAndInitRegions(
AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type,
- bool isByRef) {
+ bool isByRef, const Fortran::semantics::Symbol *sym) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };
@@ -524,13 +524,16 @@ static void createReductionAllocAndInitRegions(
mlir::Type ty = fir::unwrapRefType(type);
builder.setInsertionPointToEnd(initBlock);
mlir::Value initValue =
- genInitValueCB(builder, loc, ty, initBlock->getArgument(0));
+ isByRef ? genInitValueCB(builder, loc, ty, initBlock->getArgument(0),
+ initBlock->getArgument(1))
+ : genInitValueCB(builder, loc, ty, initBlock->getArgument(0),
+ mlir::Value{});
if (isByRef) {
populateByRefInitAndCleanupRegions(
converter, loc, type, initValue, initBlock,
reductionDecl.getInitializerAllocArg(),
reductionDecl.getInitializerMoldArg(), reductionDecl.getCleanupRegion(),
- DeclOperationKind::Reduction, /*sym=*/nullptr,
+ DeclOperationKind::Reduction, sym,
/*cannotHaveLowerBounds=*/false,
/*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
}
@@ -559,7 +562,8 @@ template <typename DeclareRedType>
DeclareRedType ReductionProcessor::createDeclareReductionHelper(
AbstractConverter &converter, llvm::StringRef reductionOpName,
mlir::Type type, mlir::Location loc, bool isByRef,
- GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) {
+ GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB,
+ const semantics::Symbol *sym) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::OpBuilder::InsertionGuard guard(builder);
mlir::ModuleOp module = builder.getModule();
@@ -593,7 +597,7 @@ DeclareRedType ReductionProcessor::createDeclareReductionHelper(
decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type,
boxedTyAttr);
createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type,
- isByRef);
+ isByRef, sym);
builder.createBlock(&decl.getReductionRegion(),
decl.getReductionRegion().end(), {type, type},
{loc, loc});
@@ -622,7 +626,8 @@ OpType ReductionProcessor::createDeclareReduction(
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef) {
auto genInitValueCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Type type, mlir::Value val) {
+ mlir::Type type, mlir::Value /*moldArg*/,
+ mlir::Value /*privArg*/) {
mlir::Type ty = fir::unwrapRefType(type);
mlir::Value initValue = ReductionProcessor::getReductionInitValue(
loc, unwrapSeqOrBoxedType(ty), redId, builder);
@@ -642,11 +647,11 @@ OpType ReductionProcessor::createDeclareReduction(
bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
if (forceByrefReduction)
return true;
-
- if (!fir::isa_trivial(fir::unwrapRefType(reductionType)) &&
- !fir::isa_derived(fir::unwrapRefType(reductionType)))
+ // Non-trivial, non-derived types (e.g., boxes, arrays) must be by-ref.
+ // Derived types must also be by-ref because user-defined combiners
+ // operate on components via side-effects, not by producing a whole value.
+ if (!fir::isa_trivial(fir::unwrapRefType(reductionType)))
return true;
-
return false;
}
@@ -798,6 +803,16 @@ bool ReductionProcessor::processReductionArguments(
}
reductionName = getReductionName(redId, kindMap, redType, isByRef);
+ // If a user-defined declare reduction already exists for this
+ // operator+type, reuse it instead of generating a new one
+ // (which would fail for non-predefined types like derived types).
+ mlir::ModuleOp module = builder.getModule();
+ if (auto existingDecl = module.lookupSymbol<OpType>(reductionName)) {
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ builder.getContext(), existingDecl.getSymName()));
+ ++idx;
+ continue;
+ }
} else if (const auto *reductionIntrinsic =
std::get_if<omp::clause::ProcedureDesignator>(
&redOperator.u)) {
diff --git a/flang/test/Lower/OpenMP/declare-reduction-finalizer.f90 b/flang/test/Lower/OpenMP/declare-reduction-finalizer.f90
new file mode 100644
index 0000000000000..2ec34c446e793
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-reduction-finalizer.f90
@@ -0,0 +1,86 @@
+! Test declare reduction with derived types that have FINAL subroutines
+! and non-trivial user-defined initializers, to verify that initialization
+! and finalization are generated correctly.
+!
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! ---------------------------------------------------------------------
+! Test 1: Simple derived type with finalizer and structure constructor init
+! ---------------------------------------------------------------------
+module m1
+ implicit none
+
+ type :: t
+ integer :: x = -999
+ contains
+ final :: cleanup
+ end type t
+
+contains
+
+ subroutine cleanup(this)
+ type(t), intent(inout) :: this
+ this%x = 0
+ end subroutine cleanup
+
+end module m1
+
+! CHECK-LABEL: omp.declare_reduction @plus_t{{.*}} : !fir.ref<{{.*}}>
+!
+! -- alloc region
+! CHECK: alloc {
+! CHECK: %[[ALLOCA:.*]] = fir.alloca
+! CHECK: omp.yield(%[[ALLOCA]] :
+!
+! -- init region: must store 100 (from initializer clause), not -999 (default)
+! CHECK: } init {
+! CHECK: ^bb0(%[[INIT_ARG0:.*]]: !fir.ref<{{.*}}>, %[[INIT_ARG1:.*]]: !fir.ref<{{.*}}>):
+! CHECK: %{{.*}}:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"}
+! CHECK: %[[PRIV_DECL:.*]]:2 = hlfir.declare %[[INIT_ARG1]] {uniq_name = "omp_priv"}
+! CHECK: %[[INIT_ADDR:.*]] = fir.address_of(@_QQro._QMm1Tt.0)
+! CHECK: %[[INIT_DECL:.*]]:2 = hlfir.declare %[[INIT_ADDR]]
+! CHECK: %[[INIT_VAL:.*]] = fir.load %[[INIT_DECL]]#0
+! CHECK: fir.store %[[INIT_VAL]] to %[[INIT_ARG1]]
+! CHECK: omp.yield(%[[INIT_ARG1]] :
+!
+! -- combiner region
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[LHS:.*]]: !fir.ref<{{.*}}>, %[[RHS:.*]]: !fir.ref<{{.*}}>):
+! CHECK: %{{.*}}:2 = hlfir.declare %[[RHS]] {uniq_name = "omp_in"}
+! CHECK: %{{.*}}:2 = hlfir.declare %[[LHS]] {uniq_name = "omp_out"}
+! CHECK: hlfir.assign %{{.*}} to %{{.*}} : i32, !fir.ref<i32>
+! CHECK: omp.yield(%[[LHS]] :
+! -- cleanup region: calls runtime destroy (which dispatches to the finalizer)
+! CHECK: } cleanup {
+! CHECK: ^bb0(%[[CLEANUP_ARG:.*]]: !fir.ref<{{.*}}>):
+! CHECK: %[[BOX:.*]] = fir.embox %[[CLEANUP_ARG]]
+! CHECK: %[[CONV:.*]] = fir.convert %[[BOX]]
+! CHECK: fir.call @_FortranADestroy(%[[CONV]])
+! CHECK: omp.yield
+! CHECK: }
+!
+! TODO: Test declare reduction without an initializer clause to verify
+! the default constructor value (-999) is used. This requires support
+! for declare reduction without an initializer clause.
+
+! Verify the init value constant is 100 (from T(100)), not -999 (default)
+! CHECK: fir.global internal @_QQro._QMm1Tt.0 constant
+! CHECK: %[[C100:.*]] = arith.constant 100 : i32
+! CHECK: fir.insert_value %{{.*}}, %[[C100]], ["x",
+
+program test1
+ use m1
+ implicit none
+
+ type(t) :: a
+
+ !$omp declare reduction(plus_t:t: omp_out%x = omp_out%x + omp_in%x) &
+ !$omp& initializer(omp_priv = t(100))
+
+ a = t(200)
+
+ !$omp parallel reduction(plus_t:a)
+ a%x = a%x + 1
+ !$omp end parallel
+
+end program test1
diff --git a/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90 b/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
index 8b5051b63afd4..f519ddb6b5989 100644
--- a/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
+++ b/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
@@ -1,10 +1,9 @@
-! RUN: not %flang_fc1 -emit-mlir -fopenmp %s -o - 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
program test
type t
integer :: x
end type t
- ! CHECK: not yet implemented: Reduction of some types is not supported
!$omp declare reduction(+:t: omp_out%x = omp_out%x + omp_in%x) initializer(omp_priv = t(0))
type(t) :: a
a = t(0)
@@ -12,3 +11,27 @@ program test
a%x = a%x + 1
!$omp end parallel
end program test
+
+! CHECK: omp.declare_reduction @add_reduction_byref_rec__QFTt :
+! CHECK: %[[ALLOCA:.*]] = fir.alloca [[TY:.*]]
+! CHECK: omp.yield(%[[ALLOCA]] : !fir.ref<[[TY]]>)
+! CHECK: } init {
+! CHECK: ^bb0(%[[INIT_ARG0:.*]]: !fir.ref<[[TY]]>, %[[INIT_ARG1:.*]]: !fir.ref<[[TY]]>):
+! CHECK: %{{.*}} = fir.embox %[[INIT_ARG1]]
+! CHECK: %{{.*}} = fir.embox %[[INIT_ARG0]]
+! CHECK: %{{.*}}:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"}
+! CHECK: %{{.*}}:2 = hlfir.declare %[[INIT_ARG1]] {uniq_name = "omp_priv"}
+! CHECK: omp.yield(%[[INIT_ARG1]] : !fir.ref<[[TY]]>)
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<[[TY]]>, %[[ARG1:.*]]: !fir.ref<[[TY]]>):
+! CHECK: %[[OMP_IN:.*]]:2 = hlfir.declare %[[ARG1]] {uniq_name = "omp_in"}
+! CHECK: %[[OMP_OUT:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "omp_out"}
+! CHECK: %[[OUT_X:.*]] = hlfir.designate %[[OMP_OUT]]#0{"x"} : (!fir.ref<[[TY]]>) -> !fir.ref<i32>
+! CHECK: %[[OUT_X_VAL:.*]] = fir.load %[[OUT_X]] : !fir.ref<i32>
+! CHECK: %[[IN_X:.*]] = hlfir.designate %[[OMP_IN]]#0{"x"} : (!fir.ref<[[TY]]>) -> !fir.ref<i32>
+! CHECK: %[[IN_X_VAL:.*]] = fir.load %[[IN_X]] : !fir.ref<i32>
+! CHECK: %[[ADD:.*]] = arith.addi %[[OUT_X_VAL]], %[[IN_X_VAL]] : i32
+! CHECK: %[[OUT_X2:.*]] = hlfir.designate %[[OMP_OUT]]#0{"x"} : (!fir.ref<[[TY]]>) -> !fir.ref<i32>
+! CHECK: hlfir.assign %[[ADD]] to %[[OUT_X2]] : i32, !fir.ref<i32>
+! CHECK: omp.yield(%[[ARG0]] : !fir.ref<[[TY]]>)
+! CHECK: }
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
index 7e481a9264117..1fea2aee64f69 100644
--- a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
@@ -1,5 +1,5 @@
! This test checks lowering of OpenMP declare reduction Directive, with initialization
-! via a subroutine. This functionality is currently not implemented.
+! via a subroutine.
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
module maxtype_mod
@@ -41,35 +41,31 @@ function func(x, n, init)
end function func
end module maxtype_mod
-!CHECK: omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init {
-!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: [[MAXTYPE]]):
-!CHECK: %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
-!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
-!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: omp.declare_reduction @red_add_max : !fir.ref<[[MAXTYPE:.*]]> {{.*}} alloc {
+!CHECK: %[[ALLOCA:.*]] = fir.alloca [[MAXTYPE:.*]]
+!CHECK: omp.yield(%[[ALLOCA]] : !fir.ref<[[MAXTYPE]]>)
+!CHECK: } init {
+!CHECK: ^bb0(%[[INIT_ARG0:.*]]: !fir.ref<[[MAXTYPE]]>, %[[INIT_ARG1:.*]]: !fir.ref<[[MAXTYPE]]>):
+!CHECK: %{{.*}} = fir.embox %[[INIT_ARG1]]
+!CHECK: %{{.*}} = fir.embox %[[INIT_ARG0]]
+!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[INIT_ARG1]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
-!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]])
+!CHECK: omp.yield(%[[INIT_ARG1]] : !fir.ref<[[MAXTYPE]]>)
!CHECK: } combiner {
-!CHECK: ^bb0(%[[LHS_ARG:.*]]: [[MAXTYPE]], %[[RHS_ARG:.*]]: [[MAXTYPE]]):
+!CHECK: ^bb0(%[[LHS_ARG:.*]]: !fir.ref<[[MAXTYPE]]>, %[[RHS_ARG:.*]]: !fir.ref<[[MAXTYPE]]>):
!CHECK: %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"}
-!CHECK: %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
-!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
-!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[OMP_IN:.*]]:2 = hlfir.declare %[[RHS_ARG]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[OMP_OUT:.*]]:2 = hlfir.declare %[[LHS_ARG]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[TMPRESULT:.*]]:2 = hlfir.declare %[[RESULT]] {uniq_name = ".tmp.func_result"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
+!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT]]#0, %[[OMP_IN]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
!CHECK: fir.save_result %[[COMBINE_RESULT]] to %[[TMPRESULT]]#0 : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]>
!CHECK: %false = arith.constant false
!CHECK: %[[EXPRRESULT:.*]] = hlfir.as_expr %[[TMPRESULT]]#0 move %false : (!fir.ref<[[MAXTYPE]]>, i1) -> !hlfir.expr<[[MAXTYPE]]>
-!CHECK: %[[ASSOCIATE:.*]]:3 = hlfir.associate %[[EXPRRESULT]] {adapt.valuebyref} : (!hlfir.expr<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>, i1)
-!CHECK: %[[RESULT_VAL:.*]] = fir.load %[[ASSOCIATE]]#0 : !fir.ref<[[MAXTYPE]]>
-!CHECK: hlfir.end_associate %[[ASSOCIATE]]#1, %[[ASSOCIATE]]#2 : !fir.ref<[[MAXTYPE]]>, i1
-!CHECK: omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]])
+!CHECK: hlfir.assign %[[EXPRRESULT]] to %[[OMP_OUT]]#0 : !hlfir.expr<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>
+!CHECK: hlfir.destroy %[[EXPRRESULT]] : !hlfir.expr<[[MAXTYPE]]>
+!CHECK: omp.yield(%[[LHS_ARG]] : !fir.ref<[[MAXTYPE]]>)
!CHECK: }
!CHECK: func.func @_QMmaxtype_modPinitme(%[[X_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %[[N_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {
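The first test above covers an intrinsic `+` declared on a derived type. The module below is a sketch of the source-level pattern that lowering now handles. The module name `declred_demo` and the function `sum_to` are hypothetical names chosen for illustration and are not part of this commit. The sketch also assumes that the compiler accepts a module-scope `declare reduction` and allows it to be used inside a module procedure.

```fortran
! Hypothetical illustration: names below are not from the commit.
module declred_demo
  implicit none

  type t
    integer :: x
  end type t

  ! Intrinsic "+" declared for the derived type t, as in the first test:
  ! the combiner adds the x components and the initializer zeroes them.
  !$omp declare reduction(+:t: omp_out%x = omp_out%x + omp_in%x) initializer(omp_priv = t(0))

contains

  ! Sums 1..n into the x component using the user-defined reduction.
  function sum_to(n) result(total)
    integer, intent(in) :: n
    type(t) :: total
    type(t) :: acc
    integer :: i

    acc = t(0)
    ! Each thread gets a private acc initialized via omp_priv = t(0);
    ! the partial results are merged into the original acc by the combiner.
    !$omp parallel do reduction(+:acc)
    do i = 1, n
      acc%x = acc%x + i
    end do
    !$omp end parallel do

    total = acc
  end function sum_to

end module declred_demo
```

Because the combiner only adds integers in the `x` component, the `x` component of `sum_to(100)` should be 5050 for any number of threads. It should also be 5050 when the OpenMP directives are ignored.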