[flang-commits] [flang] [flang][OpenMP] Support user-defined declare reduction with derived types (PR #184897)
via flang-commits
flang-commits at lists.llvm.org
Thu Mar 5 13:59:42 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-flang-openmp
Author: Matt (MattPD)
<details>
<summary>Changes</summary>
Fix lowering of `!$omp declare reduction` for intrinsic operators applied
to user-defined derived types (e.g., `+` on `type(t)`). Previously, this
hit a TODO in `ReductionProcessor::getReductionInitValue` because the code
tried to compute an init value for a non-predefined type, when it should
instead use the initializer region from the `DeclareReductionOp`.
The root cause was a naming mismatch: `genOMP` for
`OpenMPDeclareReductionConstruct` used a raw operator string (e.g., "Add")
as the reduction name, while `processReductionArguments` at the use site
computed a canonical name via `getReductionName` (e.g.,
"add_reduction_byref_rec__QFTt"). The `lookupSymbol` in
`createDeclareReductionHelper` never found the already-created op, so it
fell through to `createDeclareReduction` which called `getReductionInitValue`
with the derived type and hit the TODO.
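
As a rough illustration of the mismatch, the sketch below models the module's symbol table as a plain `std::map`. It is a toy model, not Flang or MLIR code: `SymbolTable`, `declSiteName`, `useSiteName`, and `lookupSucceeds` are hypothetical names introduced only for this example, and the two strings are the example names quoted above.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy stand-in for the module symbol table (assumption: not an MLIR type).
using SymbolTable = std::map<std::string, int>;

// Hypothetical: the name the declaration site registered before the fix.
std::string declSiteName() { return "Add"; }

// Hypothetical: the canonical name the use site computes.
std::string useSiteName() { return "add_reduction_byref_rec__QFTt"; }

// Registers one op under `registered`, then looks it up under `queried`.
// Returns true only when the two names agree.
bool lookupSucceeds(const std::string &registered, const std::string &queried) {
  SymbolTable table;
  table.emplace(registered, 1);
  return table.find(queried) != table.end();
}

// Self-check: mismatched names miss, identical names hit.
void checkNamingExample() {
  assert(!lookupSucceeds(declSiteName(), useSiteName()));
  assert(lookupSucceeds(useSiteName(), useSiteName()));
}
```

With mismatched names the lookup misses, which corresponds to the fall-through into the creation path described above.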
The fix has three parts:
1. Consistent names: In `genOMP` for `OpenMPDeclareReductionConstruct`, compute
the reduction name using the same `getReductionName` scheme that
`processReductionArguments` uses, so both sites produce identical symbol names.
For intrinsic operators, this maps through `ReductionIdentifier` to get the
canonical name. For user-defined named reductions, the raw symbol name is used
directly, matching the existing custom-reduction lookup path.
2. Reuse reduction: In `processReductionArguments`, when an intrinsic operator
reduction is requested, check whether a user-defined declare reduction already
exists under that canonical name before attempting to create a new one. If
found, reuse it. This avoids calling `createDeclareReduction` (and thus
`getReductionInitValue`) for types that have user-provided initializers.
3. Reference semantics: Change `doReductionByRef` to return true for derived
types. Previously it returned false for both trivial and derived types, treating
derived types as by-val. This is incorrect for user-defined combiners that
operate on components via side-effects (e.g., `omp_out%x = omp_out%x +
omp_in%x`): the combiner mutates `omp_out` in place and doesn't produce a
whole-struct value, so `convertExprToValue` returns the component type
(`i32`) rather than the struct type, causing a type mismatch in the
`omp.yield`. By-ref is the correct model: the combiner stores into the
lhs reference and yields it.
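
Parts 2 and 3 can be sketched in isolation, under simplifying assumptions. The `TypeKind` enum, the `Module` struct, and the functions `reduceByRef` and `getOrCreate` are hypothetical stand-ins, not Flang APIs. The sketch only shows the shape of the logic: trivial types stay by-value, and an existing declaration is reused before a new one is created.

```cpp
#include <cassert>
#include <map>
#include <string>

// Simplified stand-in for the FIR type classes relevant here (assumption).
enum class TypeKind { Trivial, Derived, Other };

// Shape of the changed predicate: only trivial types are reduced by value.
bool reduceByRef(TypeKind kind, bool forceByRef = false) {
  if (forceByRef)
    return true;
  return kind != TypeKind::Trivial;
}

// Toy module: maps a reduction name to a label saying who created it.
struct Module {
  std::map<std::string, std::string> decls;
  int created = 0;
};

// Look up first; create a compiler-generated declaration only on a miss.
const std::string &getOrCreate(Module &m, const std::string &name) {
  auto it = m.decls.find(name);
  if (it != m.decls.end())
    return it->second;
  ++m.created;
  return m.decls.emplace(name, "compiler-generated").first->second;
}

// Self-check: a pre-registered user declaration is reused, not recreated.
void checkReuseExample() {
  Module m;
  m.decls["add_reduction_byref_rec__QFTt"] = "user-defined";
  assert(getOrCreate(m, "add_reduction_byref_rec__QFTt") == "user-defined");
  assert(m.created == 0);
}
```

In this model the creation counter stays at zero when a user-defined entry already exists, which is the property that keeps the init-value path from being reached for derived types.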
The combiner callback in `processReductionCombiner` is also updated to
handle the by-ref derived-type case: when the combiner result type
doesn't match the element type (as happens with component-level
assignments), the store is skipped since the assignment already wrote
into omp_out as a side-effect, and only the lhs reference is yielded.
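
The conditional store can be modelled with types represented as strings. This is a minimal sketch, not the real builder code: `Value`, `CombinerResult`, `unwrapRef`, and `finishByRefCombiner` are hypothetical, and `unwrapRef` is only a toy analogue of `fir::unwrapRefType`.

```cpp
#include <cassert>
#include <string>

// Toy value: only its type (as text) matters for this sketch.
struct Value {
  std::string type;
};

// What the by-ref combiner tail did: whether it stored, and what it yielded.
struct CombinerResult {
  bool stored;
  std::string yieldedType;
};

// Strips one "ref<...>" wrapper; returns the input unchanged otherwise.
std::string unwrapRef(const std::string &t) {
  const std::string prefix = "ref<";
  if (t.size() > prefix.size() + 1 &&
      t.compare(0, prefix.size(), prefix) == 0 && t.back() == '>')
    return t.substr(prefix.size(), t.size() - prefix.size() - 1);
  return t;
}

// Store the result only when its type equals the element type of lhs;
// the lhs reference is yielded in both cases.
CombinerResult finishByRefCombiner(const Value &result, const Value &lhs) {
  bool store = result.type == unwrapRef(lhs.type);
  return {store, lhs.type};
}

// Self-check: a component-typed result (i32) skips the store.
void checkCombinerExample() {
  CombinerResult r = finishByRefCombiner({"i32"}, {"ref<t>"});
  assert(!r.stored);
  assert(r.yieldedType == "ref<t>");
}
```

A whole-struct result (type `t` against `ref<t>`) would still be stored, matching the function-call combiner case checked in the second test file.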
Test updates:
- Update declare-reduction-intrinsic-op.f90 from a negative test (checking
for the TODO error) to a positive test checking the generated MLIR.
- Update omp-declare-reduction-derivedtype.f90 CHECK lines to match the
reference semantics fix: the `declare_reduction` now has type `!fir.ref<...>`
with a `byref_element_type` attribute, an alloc region, a two-argument init
region, and a combiner that stores into the lhs and yields the
reference. The function body checks for initme and mycombine are
unchanged in substance but use literal type names instead of a regex
capture to avoid greedy matching issues with nested angle brackets.
Remaining work: declare reduction without an initializer clause is not yet
supported. I plan to address that subsequently.
Note: I relied on an LLM (Claude Opus 4.6) to help navigate the Flang APIs and
to assist with the corresponding boilerplate code and test updates. In
particular, to get the aforementioned consistent naming from
`ReductionProcessor::getReductionName`, I had to drop
`parser::DefinedOperator::EnumToString` and instead introduce
`getRedIdFromParserIntrOp`, which does the conversion manually. Is there an
existing conversion function for this? AFAICT there is none, but I might have
missed it. In any case, feedback welcome!
---
Full diff: https://github.com/llvm/llvm-project/pull/184897.diff
4 Files Affected:
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+72-24)
- (modified) flang/lib/Lower/Support/ReductionProcessor.cpp (+15-4)
- (modified) flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90 (+11-2)
- (modified) flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 (+20-21)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index e2018add11206..91f23621fdc4a 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3864,7 +3864,13 @@ static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
evalExpr.u);
stmtCtx.finalizeAndPop();
if (isByRef) {
- fir::StoreOp::create(builder, loc, result, lhs);
+ // For user-defined combiners the assignment expression (e.g.
+ // "omp_out%x = omp_out%x + omp_in%x") already wrote into omp_out
+ // as a side-effect. We only need to yield the lhs reference.
+ // Only store result back if its type actually matches the element type.
+ mlir::Type eleTy = fir::unwrapRefType(lhs.getType());
+ if (result.getType() == eleTy)
+ fir::StoreOp::create(builder, loc, result, lhs);
mlir::omp::YieldOp::create(builder, loc, lhs);
} else {
mlir::omp::YieldOp::create(builder, loc, result);
@@ -3940,6 +3946,30 @@ appendCombiner(const parser::OpenMPDeclareReductionConstruct &construct,
llvm_unreachable("Expecting reduction combiner");
}
+// Map parser intrinsic operators to ReductionIdentifier.
+// Only operators valid for OpenMP reductions are mapped.
+static ReductionProcessor::ReductionIdentifier
+getRedIdFromParserIntrOp(parser::DefinedOperator::IntrinsicOperator intrOp) {
+ switch (intrOp) {
+ case parser::DefinedOperator::IntrinsicOperator::Add:
+ return ReductionProcessor::ReductionIdentifier::ADD;
+ case parser::DefinedOperator::IntrinsicOperator::Subtract:
+ return ReductionProcessor::ReductionIdentifier::SUBTRACT;
+ case parser::DefinedOperator::IntrinsicOperator::Multiply:
+ return ReductionProcessor::ReductionIdentifier::MULTIPLY;
+ case parser::DefinedOperator::IntrinsicOperator::AND:
+ return ReductionProcessor::ReductionIdentifier::AND;
+ case parser::DefinedOperator::IntrinsicOperator::OR:
+ return ReductionProcessor::ReductionIdentifier::OR;
+ case parser::DefinedOperator::IntrinsicOperator::EQV:
+ return ReductionProcessor::ReductionIdentifier::EQV;
+ case parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return ReductionProcessor::ReductionIdentifier::NEQV;
+ default:
+ llvm_unreachable("unexpected intrinsic operator for reduction");
+ }
+}
+
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
@@ -3957,28 +3987,6 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
const auto &identifier =
std::get<parser::OmpReductionIdentifier>(specifier.t);
- std::string reductionNameStr = Fortran::common::visit(
- common::visitors{
- [](const parser::ProcedureDesignator &pd) -> std::string {
- return std::get<parser::Name>(pd.u).ToString();
- },
- [](const parser::DefinedOperator &defOp) -> std::string {
- return Fortran::common::visit(
- common::visitors{
- [](const parser::DefinedOpName &opName) -> std::string {
- return opName.v.ToString();
- },
- [](parser::DefinedOperator::IntrinsicOperator intrOp)
- -> std::string {
- return std::string(
- parser::DefinedOperator::EnumToString(intrOp));
- },
- },
- defOp.u);
- },
- },
- identifier.u);
-
for (const auto &typeSpec : typeNameList.v) {
(void)typeSpec; // Currently unused
mlir::Type reductionType = getReductionType(converter, specifier);
@@ -3988,9 +3996,49 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processInitializer(symTable, genInitValueCB);
bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+
+ // Build the reduction name to match what processReductionArguments uses
+ // when looking up the DeclareReductionOp by name. For intrinsic operators,
+ // use getReductionName(redId, ...) to produce the canonical name (e.g.,
+ // "add_reduction_byref_rec_..."). For user-defined reductions, use the
+ // raw symbol name to match the sym->name().ToString() lookup path.
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ const auto &kindMap = builder.getKindMap();
+ mlir::Type redType = isByRef
+ ? static_cast<mlir::Type>(fir::ReferenceType::get(reductionType))
+ : reductionType;
+
+ std::string reductionNameStr = Fortran::common::visit(
+ common::visitors{
+ [&](const parser::ProcedureDesignator &pd) -> std::string {
+ // User-defined named reductions: use raw name to match
+ // lookupSymbol in processReductionArguments.
+ return std::get<parser::Name>(pd.u).ToString();
+ },
+ [&](const parser::DefinedOperator &defOp) -> std::string {
+ return Fortran::common::visit(
+ common::visitors{
+ [&](const parser::DefinedOpName &opName) -> std::string {
+ // User-defined operator reductions: use raw name.
+ return opName.v.ToString();
+ },
+ [&](parser::DefinedOperator::IntrinsicOperator intrOp)
+ -> std::string {
+ // Intrinsic operators: use canonical naming to match
+ // processReductionArguments lookup.
+ auto redId = getRedIdFromParserIntrOp(intrOp);
+ return ReductionProcessor::getReductionName(
+ redId, kindMap, redType, isByRef);
+ },
+ },
+ defOp.u);
+ },
+ },
+ identifier.u);
+
ReductionProcessor::createDeclareReductionHelper<
mlir::omp::DeclareReductionOp>(
- converter, reductionNameStr, reductionType,
+ converter, reductionNameStr, redType,
converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
}
}
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index e0cba4c512258..1a6f372a10faa 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -642,11 +642,11 @@ OpType ReductionProcessor::createDeclareReduction(
bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
if (forceByrefReduction)
return true;
-
- if (!fir::isa_trivial(fir::unwrapRefType(reductionType)) &&
- !fir::isa_derived(fir::unwrapRefType(reductionType)))
+ // Non-trivial, non-derived types (e.g., boxes, arrays) must be by-ref.
+ // Derived types must also be by-ref because user-defined combiners
+ // operate on components via side-effects, not by producing a whole value.
+ if (!fir::isa_trivial(fir::unwrapRefType(reductionType)))
return true;
-
return false;
}
@@ -798,6 +798,17 @@ bool ReductionProcessor::processReductionArguments(
}
reductionName = getReductionName(redId, kindMap, redType, isByRef);
+ // If a user-defined declare reduction already exists for this
+ // operator+type, reuse it instead of generating a new one
+ // (which would fail for non-predefined types like derived types).
+ mlir::ModuleOp module = builder.getModule();
+ if (auto existingDecl =
+ module.lookupSymbol<OpType>(reductionName)) {
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ builder.getContext(), existingDecl.getSymName()));
+ ++idx;
+ continue;
+ }
} else if (const auto *reductionIntrinsic =
std::get_if<omp::clause::ProcedureDesignator>(
&redOperator.u)) {
diff --git a/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90 b/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
index 8b5051b63afd4..454c0b428988c 100644
--- a/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
+++ b/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
@@ -1,10 +1,9 @@
-! RUN: not %flang_fc1 -emit-mlir -fopenmp %s -o - 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-mlir -fopenmp %s -o - | FileCheck %s
program test
type t
integer :: x
end type t
- ! CHECK: not yet implemented: Reduction of some types is not supported
!$omp declare reduction(+:t: omp_out%x = omp_out%x + omp_in%x) initializer(omp_priv = t(0))
type(t) :: a
a = t(0)
@@ -12,3 +11,13 @@ program test
a%x = a%x + 1
!$omp end parallel
end program test
+
+! CHECK: omp.declare_reduction @add_reduction_byref_rec__QFTt : !fir.ref<!fir.type<_QFTt{x:i32}>>
+! CHECK-SAME: attributes {byref_element_type = !fir.type<_QFTt{x:i32}>}
+! CHECK: alloc {
+! CHECK: omp.yield
+! CHECK: } init {
+! CHECK: omp.yield
+! CHECK: } combiner {
+! CHECK: omp.yield
+! CHECK: }
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
index ff70acbb10e32..4ca735d6105f1 100644
--- a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
@@ -1,5 +1,5 @@
! This test checks lowering of OpenMP declare reduction Directive, with initialization
-! via a subroutine. This functionality is currently not implemented.
+! via a subroutine.
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
module maxtype_mod
@@ -41,35 +41,34 @@ function func(x, n, init)
end function func
end module maxtype_mod
-!CHECK: omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init {
-!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: [[MAXTYPE]]):
-!CHECK: %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
-!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
-!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
-!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
-!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]])
+!CHECK: omp.declare_reduction @red_add_max : !fir.ref<{{.*}}> attributes {byref_element_type = {{.*}}} alloc {
+!CHECK: %[[ALLOCA:.*]] = fir.alloca [[MAXTYPE:.*]]
+!CHECK: omp.yield(%[[ALLOCA]] : !fir.ref<[[MAXTYPE]]>)
+!CHECK: } init {
+!CHECK: ^bb0(%[[INIT_ARG0:.*]]: !fir.ref<[[MAXTYPE]]>, %[[INIT_ARG1:.*]]: !fir.ref<[[MAXTYPE]]>):
+!CHECK: %{{.*}} = fir.embox %[[INIT_ARG1]]
+!CHECK: %{{.*}} = fir.embox %[[INIT_ARG0]]
+!CHECK: %[[OMP_ORIG:.*]]:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[OMP_PRIV:.*]]:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV]]#0, %[[OMP_ORIG]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
+!CHECK: %{{.*}} = fir.load %[[OMP_PRIV]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK: omp.yield(%[[INIT_ARG1]] : !fir.ref<[[MAXTYPE]]>)
!CHECK: } combiner {
-!CHECK: ^bb0(%[[LHS_ARG:.*]]: [[MAXTYPE]], %[[RHS_ARG:.*]]: [[MAXTYPE]]):
+!CHECK: ^bb0(%[[LHS_ARG:.*]]: !fir.ref<[[MAXTYPE]]>, %[[RHS_ARG:.*]]: !fir.ref<[[MAXTYPE]]>):
!CHECK: %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"}
-!CHECK: %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
-!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
-!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[OMP_IN:.*]]:2 = hlfir.declare %[[RHS_ARG]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[OMP_OUT:.*]]:2 = hlfir.declare %[[LHS_ARG]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[TMPRESULT:.*]]:2 = hlfir.declare %[[RESULT]] {uniq_name = ".tmp.func_result"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
+!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT]]#0, %[[OMP_IN]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
!CHECK: fir.save_result %[[COMBINE_RESULT]] to %[[TMPRESULT]]#0 : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]>
!CHECK: %false = arith.constant false
!CHECK: %[[EXPRRESULT:.*]] = hlfir.as_expr %[[TMPRESULT]]#0 move %false : (!fir.ref<[[MAXTYPE]]>, i1) -> !hlfir.expr<[[MAXTYPE]]>
!CHECK: %[[ASSOCIATE:.*]]:3 = hlfir.associate %[[EXPRRESULT]] {adapt.valuebyref} : (!hlfir.expr<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>, i1)
!CHECK: %[[RESULT_VAL:.*]] = fir.load %[[ASSOCIATE]]#0 : !fir.ref<[[MAXTYPE]]>
!CHECK: hlfir.end_associate %[[ASSOCIATE]]#1, %[[ASSOCIATE]]#2 : !fir.ref<[[MAXTYPE]]>, i1
-!CHECK: omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]])
+!CHECK: hlfir.destroy %[[EXPRRESULT]] : !hlfir.expr<[[MAXTYPE]]>
+!CHECK: fir.store %[[RESULT_VAL]] to %[[LHS_ARG]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: omp.yield(%[[LHS_ARG]] : !fir.ref<[[MAXTYPE]]>)
!CHECK: }
!CHECK: func.func @_QMmaxtype_modPinitme(%[[X_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %[[N_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/184897