[flang-commits] [flang] [flang][OpenMP] Support user-defined declare reduction with derived types (PR #184897)

via flang-commits flang-commits at lists.llvm.org
Fri Mar 13 14:01:56 PDT 2026


https://github.com/MattPD updated https://github.com/llvm/llvm-project/pull/184897

>From 60209b488b62ef911bd51ac7761d2816fc006907 Mon Sep 17 00:00:00 2001
From: "Matt P. Dziubinski" <matt-p.dziubinski at hpe.com>
Date: Thu, 5 Mar 2026 15:56:29 -0600
Subject: [PATCH 1/3] [flang][OpenMP] Support user-defined declare reduction
 with derived types

Fix lowering of `!$omp declare reduction` for intrinsic operators applied
to user-defined derived types (e.g., `+` on `type(t)`). Previously, this
hit a TODO in `ReductionProcessor::getReductionInitValue` because the code
tried to compute an init value for a non-predefined type, when it should
instead use the initializer region from the `DeclareReductionOp`.

The root cause was a naming mismatch: `genOMP` for
`OpenMPDeclareReductionConstruct` used a raw operator string (e.g., "Add")
as the reduction name, while `processReductionArguments` at the use site
computed a canonical name via `getReductionName` (e.g.,
"add_reduction_byref_rec__QFTt"). The `lookupSymbol` in
`createDeclareReductionHelper` never found the already-created op, so it
fell through to `createDeclareReduction`, which called `getReductionInitValue`
with the derived type and hit the TODO.
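
To make the mismatch concrete, here is a minimal standalone sketch. It is not
Flang code: `SymbolTable`, `canonicalName`, and `lookup` are hypothetical
stand-ins for the module symbol table, `getReductionName`, and `lookupSymbol`,
and the name format is only modeled on the example string above.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Toy stand-in for the module symbol table (assumption, not Flang's type).
using SymbolTable = std::map<std::string, int>;

// Hypothetical canonical-name helper, loosely modeled on the
// "add_reduction_byref_rec__QFTt" example from the commit message.
std::string canonicalName(const std::string &op, bool byRef,
                          const std::string &mangledType) {
  assert(!op.empty() && "operator name must not be empty");
  return op + "_reduction_" + (byRef ? "byref_" : "") + mangledType;
}

// Look up a declaration id by symbol name; empty if no such symbol exists.
std::optional<int> lookup(const SymbolTable &table, const std::string &name) {
  auto it = table.find(name);
  if (it == table.end())
    return std::nullopt;
  return it->second;
}

// Before the fix: the declaration site registers the raw operator string,
// while the use site looks up the canonical name, so the lookup misses.
bool lookupSucceedsBeforeFix() {
  SymbolTable table;
  table["Add"] = 1;
  return lookup(table, canonicalName("add", true, "rec__QFTt")).has_value();
}

// After the fix: both sites derive the name from the same helper.
bool lookupSucceedsAfterFix() {
  SymbolTable table;
  table[canonicalName("add", true, "rec__QFTt")] = 1;
  return lookup(table, canonicalName("add", true, "rec__QFTt")).has_value();
}
```

In this model `lookupSucceedsBeforeFix()` returns false and
`lookupSucceedsAfterFix()` returns true, which is the behavior change the
consistent naming is meant to produce.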

The fix has three parts:

1. Consistent names: In `genOMP` for `OpenMPDeclareReductionConstruct`, compute
the reduction name using the same `getReductionName` scheme that
`processReductionArguments` uses, so both sites produce identical symbol names.
For intrinsic operators, this maps through `ReductionIdentifier` to get the
canonical name. For user-defined named reductions, the raw symbol name is used
directly, matching the existing custom-reduction lookup path.

2. Reuse reduction: In `processReductionArguments`, when an intrinsic operator
reduction is requested, check whether a user-defined declare reduction already
exists under that canonical name before attempting to create a new one. If
found, reuse it. This avoids calling `createDeclareReduction` (and thus
`getReductionInitValue`) for types that have user-provided initializers.

3. Reference semantics: Change `doReductionByRef` to return true for derived
types. Previously it returned false for both trivial and derived types, treating
derived types as by-val. This is incorrect for user-defined combiners that
operate on components via side-effects (e.g., `omp_out%x = omp_out%x +
omp_in%x`): the combiner mutates `omp_out` in place and doesn't produce a
whole-struct value, so `convertExprToValue` returns the component type
(`i32`) rather than the struct type, causing a type mismatch in the
`omp.yield`. By-ref is the correct model: the combiner stores into the
lhs reference and yields it.

The combiner callback in `processReductionCombiner` is also updated to
handle the by-ref derived-type case: when the combiner result type
doesn't match the element type (as happens with component-level
assignments), the store is skipped since the assignment already wrote
into omp_out as a side-effect, and only the lhs reference is yielded.
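
The by-ref combiner behavior can be illustrated with a small standalone model.
This is a sketch under stated assumptions, not the lowering code: `T`,
`Combiner`, `applyByRef`, `componentAdd`, and `wholeAdd` are hypothetical
names, and an empty `std::optional` stands in for "the combiner produced no
whole-struct value".

```cpp
#include <cassert>
#include <functional>
#include <optional>

// Toy derived type with a single integer component, like type(t) in the test.
struct T {
  int x;
};

// A combiner either returns a whole T (function-style combiner) or mutates
// ompOut in place and returns nothing (component-level assignment).
using Combiner = std::function<std::optional<T>(T &ompOut, const T &ompIn)>;

// By-ref application: store a whole-struct result into lhs if there is one,
// skip the store otherwise, and always yield the lhs reference.
T &applyByRef(const Combiner &combiner, T &lhs, const T &rhs) {
  assert(combiner && "combiner must be callable");
  if (std::optional<T> result = combiner(lhs, rhs))
    lhs = *result;
  return lhs;
}

// Models "omp_out%x = omp_out%x + omp_in%x": a side effect on omp_out only.
std::optional<T> componentAdd(T &ompOut, const T &ompIn) {
  ompOut.x += ompIn.x;
  return std::nullopt;
}

// Models a combiner that is a function call returning a whole struct.
std::optional<T> wholeAdd(T &ompOut, const T &ompIn) {
  return T{ompOut.x + ompIn.x};
}
```

Both styles leave the combined value in `lhs` and return a reference to it,
so the caller never has to reconcile a component-typed result with the struct
type.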

Test updates:
- Update declare-reduction-intrinsic-op.f90 from a negative test (checking
for the TODO error) to a positive test checking the generated MLIR.
- Update omp-declare-reduction-derivedtype.f90 CHECK lines to match the
reference semantics fix: the `declare_reduction` now has type `!fir.ref<...>`
with a `byref_element_type` attribute, an alloc region, a two-argument init
region, and a combiner that stores into the lhs and yields the
reference. The function body checks for initme and mycombine are
unchanged in substance but use literal type names instead of a regex
capture to avoid greedy matching issues with nested angle brackets.

Remaining work: declare reduction without an initializer clause is not yet
supported. I plan to address that in a follow-up change.

Assisted-by: Claude Opus 4.6.
---
 flang/include/flang/Lower/OpenMP/Clauses.h    | 11 +++
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 73 ++++++++++++-------
 .../lib/Lower/Support/ReductionProcessor.cpp  | 18 ++++-
 .../OpenMP/declare-reduction-intrinsic-op.f90 | 23 +++++-
 .../omp-declare-reduction-derivedtype.f90     | 41 +++++------
 5 files changed, 113 insertions(+), 53 deletions(-)

diff --git a/flang/include/flang/Lower/OpenMP/Clauses.h b/flang/include/flang/Lower/OpenMP/Clauses.h
index a325e74327240..f334374280c73 100644
--- a/flang/include/flang/Lower/OpenMP/Clauses.h
+++ b/flang/include/flang/Lower/OpenMP/Clauses.h
@@ -329,6 +329,17 @@ using UsesAllocators = tomp::clause::UsesAllocatorsT<TypeTy, IdTy, ExprTy>;
 using Weak = tomp::clause::WeakT<TypeTy, IdTy, ExprTy>;
 using When = tomp::clause::WhenT<TypeTy, IdTy, ExprTy>;
 using Write = tomp::clause::WriteT<TypeTy, IdTy, ExprTy>;
+
+DefinedOperator makeDefinedOperator(const parser::DefinedOperator &inp,
+                                    semantics::SemanticsContext &semaCtx);
+
+ProcedureDesignator
+makeProcedureDesignator(const parser::ProcedureDesignator &inp,
+                        semantics::SemanticsContext &semaCtx);
+
+ReductionOperator
+makeReductionOperator(const parser::OmpReductionIdentifier &inp,
+                      semantics::SemanticsContext &semaCtx);
 } // namespace clause
 
 using tomp::type::operator==;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index e2018add11206..818379d5a47b3 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3864,7 +3864,13 @@ static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
         evalExpr.u);
     stmtCtx.finalizeAndPop();
     if (isByRef) {
-      fir::StoreOp::create(builder, loc, result, lhs);
+      // For user-defined combiners the assignment expression (e.g.
+      // "omp_out%x = omp_out%x + omp_in%x") already wrote into omp_out
+      // as a side-effect. We only need to yield the lhs reference.
+      // Only store result back if its type actually matches the element type.
+      mlir::Type eleTy = fir::unwrapRefType(lhs.getType());
+      if (result.getType() == eleTy)
+        fir::StoreOp::create(builder, loc, result, lhs);
       mlir::omp::YieldOp::create(builder, loc, lhs);
     } else {
       mlir::omp::YieldOp::create(builder, loc, result);
@@ -3957,41 +3963,56 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
   const auto &identifier =
       std::get<parser::OmpReductionIdentifier>(specifier.t);
 
-  std::string reductionNameStr = Fortran::common::visit(
-      common::visitors{
-          [](const parser::ProcedureDesignator &pd) -> std::string {
-            return std::get<parser::Name>(pd.u).ToString();
-          },
-          [](const parser::DefinedOperator &defOp) -> std::string {
-            return Fortran::common::visit(
-                common::visitors{
-                    [](const parser::DefinedOpName &opName) -> std::string {
-                      return opName.v.ToString();
-                    },
-                    [](parser::DefinedOperator::IntrinsicOperator intrOp)
-                        -> std::string {
-                      return std::string(
-                          parser::DefinedOperator::EnumToString(intrOp));
-                    },
-                },
-                defOp.u);
-          },
-      },
-      identifier.u);
+  // Convert the parser-level reduction identifier to the clause-level
+  // representation, then use ReductionProcessor to derive the canonical name.
+  clause::ReductionOperator redOp =
+      clause::makeReductionOperator(identifier, semaCtx);
 
   for (const auto &typeSpec : typeNameList.v) {
     (void)typeSpec; // Currently unused
     mlir::Type reductionType = getReductionType(converter, specifier);
+    bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+    // Compute the canonical reduction name the same way
+    // processReductionArguments does.
+    std::string reductionNameStr = Fortran::common::visit(
+        common::visitors{
+            [&](const clause::DefinedOperator &defOp) -> std::string {
+              return Fortran::common::visit(
+                  common::visitors{
+                      [&](const clause::DefinedOperator::IntrinsicOperator
+                              &intrOp) -> std::string {
+                        ReductionProcessor::ReductionIdentifier redId =
+                            ReductionProcessor::getReductionType(intrOp);
+                        return ReductionProcessor::getReductionName(
+                            redId, converter.getFirOpBuilder().getKindMap(),
+                            reductionType, isByRef);
+                      },
+                      [&](const clause::DefinedOperator::DefinedOpName &opName)
+                          -> std::string {
+                        return opName.v.sym()->name().ToString();
+                      },
+                  },
+                  defOp.u);
+            },
+            [&](const clause::ProcedureDesignator &pd) -> std::string {
+              return pd.v.sym()->name().ToString();
+            },
+        },
+        redOp.u);
+
     ReductionProcessor::GenCombinerCBTy genCombinerCB =
         processReductionCombiner(converter, symTable, semaCtx, combiner);
     ReductionProcessor::GenInitValueCBTy genInitValueCB;
     ClauseProcessor cp(converter, semaCtx, clauses);
     cp.processInitializer(symTable, genInitValueCB);
-    bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+    mlir::Type redType =
+        isByRef
+            ? static_cast<mlir::Type>(fir::ReferenceType::get(reductionType))
+            : reductionType;
     ReductionProcessor::createDeclareReductionHelper<
-        mlir::omp::DeclareReductionOp>(
-        converter, reductionNameStr, reductionType,
-        converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
+        mlir::omp::DeclareReductionOp>(converter, reductionNameStr, redType,
+                                       converter.getCurrentLocation(), isByRef,
+                                       genCombinerCB, genInitValueCB);
   }
 }
 
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index e0cba4c512258..eaaf643ec7eeb 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -642,11 +642,11 @@ OpType ReductionProcessor::createDeclareReduction(
 bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
   if (forceByrefReduction)
     return true;
-
-  if (!fir::isa_trivial(fir::unwrapRefType(reductionType)) &&
-      !fir::isa_derived(fir::unwrapRefType(reductionType)))
+  // Non-trivial, non-derived types (e.g., boxes, arrays) must be by-ref.
+  // Derived types must also be by-ref because user-defined combiners
+  // operate on components via side-effects, not by producing a whole value.
+  if (!fir::isa_trivial(fir::unwrapRefType(reductionType)))
     return true;
-
   return false;
 }
 
@@ -798,6 +798,16 @@ bool ReductionProcessor::processReductionArguments(
         }
 
         reductionName = getReductionName(redId, kindMap, redType, isByRef);
+        // If a user-defined declare reduction already exists for this
+        // operator+type, reuse it instead of generating a new one
+        // (which would fail for non-predefined types like derived types).
+        mlir::ModuleOp module = builder.getModule();
+        if (auto existingDecl = module.lookupSymbol<OpType>(reductionName)) {
+          reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+              builder.getContext(), existingDecl.getSymName()));
+          ++idx;
+          continue;
+        }
       } else if (const auto *reductionIntrinsic =
                      std::get_if<omp::clause::ProcedureDesignator>(
                          &redOperator.u)) {
diff --git a/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90 b/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
index 8b5051b63afd4..a2ec52902bba2 100644
--- a/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
+++ b/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
@@ -1,10 +1,9 @@
-! RUN: not %flang_fc1 -emit-mlir -fopenmp %s -o - 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
 
 program test
   type t
      integer :: x
   end type t
-  ! CHECK: not yet implemented: Reduction of some types is not supported
   !$omp declare reduction(+:t: omp_out%x = omp_out%x + omp_in%x) initializer(omp_priv = t(0))
   type(t) :: a
   a = t(0)
@@ -12,3 +11,23 @@ program test
   a%x = a%x + 1
   !$omp end parallel
 end program test
+
+! CHECK: omp.declare_reduction @add_reduction_byref_rec__QFTt :
+! CHECK:   %[[ALLOCA:.*]] = fir.alloca [[TY:.*]]
+! CHECK:   omp.yield(%[[ALLOCA]] : !fir.ref<[[TY]]>)
+! CHECK: } init {
+! CHECK: ^bb0(%[[INIT_ARG0:.*]]: !fir.ref<[[TY]]>, %[[INIT_ARG1:.*]]: !fir.ref<[[TY]]>):
+! CHECK:   %{{.*}} = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"}
+! CHECK:   %{{.*}} = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_priv"}
+! CHECK:   omp.yield(%[[INIT_ARG1]] : !fir.ref<[[TY]]>)
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<[[TY]]>, %[[ARG1:.*]]: !fir.ref<[[TY]]>):
+! CHECK:   %[[OMP_IN:.*]]:2 = hlfir.declare %[[ARG1]] {uniq_name = "omp_in"}
+! CHECK:   %[[OMP_OUT:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "omp_out"}
+! CHECK:   %[[OUT_X:.*]] = hlfir.designate %[[OMP_OUT]]#0{"x"} : (!fir.ref<[[TY]]>) -> !fir.ref<i32>
+! CHECK:   %[[OUT_X_VAL:.*]] = fir.load %[[OUT_X]] : !fir.ref<i32>
+! CHECK:   %[[IN_X:.*]] = hlfir.designate %[[OMP_IN]]#0{"x"} : (!fir.ref<[[TY]]>) -> !fir.ref<i32>
+! CHECK:   %[[IN_X_VAL:.*]] = fir.load %[[IN_X]] : !fir.ref<i32>
+! CHECK:   %{{.*}} = arith.addi %[[OUT_X_VAL]], %[[IN_X_VAL]] : i32
+! CHECK:   omp.yield(%[[ARG0]] : !fir.ref<[[TY]]>)
+! CHECK: }
\ No newline at end of file
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
index ff70acbb10e32..4ca735d6105f1 100644
--- a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
@@ -1,5 +1,5 @@
 ! This test checks lowering of OpenMP declare reduction Directive, with initialization
-! via a subroutine. This functionality is currently not implemented.
+! via a subroutine.
 
 !RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
 module maxtype_mod
@@ -41,35 +41,34 @@ function func(x, n, init)
   end function func
 
 end module maxtype_mod
-!CHECK:  omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init {
-!CHECK:  ^bb0(%[[OMP_ORIG_ARG_I:.*]]: [[MAXTYPE]]):
-!CHECK:    %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK:    %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK:    fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
-!CHECK:    %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK:    fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
-!CHECK:    %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK:    fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
-!CHECK:    %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
-!CHECK:    omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]])
+!CHECK:  omp.declare_reduction @red_add_max : !fir.ref<{{.*}}> attributes {byref_element_type = {{.*}}} alloc {
+!CHECK:  %[[ALLOCA:.*]] = fir.alloca [[MAXTYPE:.*]]
+!CHECK:  omp.yield(%[[ALLOCA]] : !fir.ref<[[MAXTYPE]]>)
+!CHECK:  } init {
+!CHECK:  ^bb0(%[[INIT_ARG0:.*]]: !fir.ref<[[MAXTYPE]]>, %[[INIT_ARG1:.*]]: !fir.ref<[[MAXTYPE]]>):
+!CHECK:    %{{.*}} = fir.embox %[[INIT_ARG1]]
+!CHECK:    %{{.*}} = fir.embox %[[INIT_ARG0]]
+!CHECK:    %[[OMP_ORIG:.*]]:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[OMP_PRIV:.*]]:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV]]#0, %[[OMP_ORIG]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
+!CHECK:    %{{.*}} = fir.load %[[OMP_PRIV]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK:    omp.yield(%[[INIT_ARG1]] : !fir.ref<[[MAXTYPE]]>)
 !CHECK:  } combiner {
-!CHECK:  ^bb0(%[[LHS_ARG:.*]]: [[MAXTYPE]], %[[RHS_ARG:.*]]: [[MAXTYPE]]):
+!CHECK:  ^bb0(%[[LHS_ARG:.*]]: !fir.ref<[[MAXTYPE]]>, %[[RHS_ARG:.*]]: !fir.ref<[[MAXTYPE]]>):
 !CHECK:    %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"}
-!CHECK:    %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK:    %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK:    fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
-!CHECK:    %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK:    fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
-!CHECK:    %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[OMP_IN:.*]]:2 = hlfir.declare %[[RHS_ARG]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[OMP_OUT:.*]]:2 = hlfir.declare %[[LHS_ARG]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
 !CHECK:    %[[TMPRESULT:.*]]:2 = hlfir.declare %[[RESULT]] {uniq_name = ".tmp.func_result"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK:    %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
+!CHECK:    %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT]]#0, %[[OMP_IN]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
 !CHECK:    fir.save_result %[[COMBINE_RESULT]] to %[[TMPRESULT]]#0 : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]>
 !CHECK:    %false = arith.constant false
 !CHECK:    %[[EXPRRESULT:.*]] = hlfir.as_expr %[[TMPRESULT]]#0 move %false : (!fir.ref<[[MAXTYPE]]>, i1) -> !hlfir.expr<[[MAXTYPE]]>
 !CHECK:    %[[ASSOCIATE:.*]]:3 = hlfir.associate %[[EXPRRESULT]] {adapt.valuebyref} : (!hlfir.expr<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>, i1)
 !CHECK:    %[[RESULT_VAL:.*]] = fir.load %[[ASSOCIATE]]#0 : !fir.ref<[[MAXTYPE]]>
 !CHECK:    hlfir.end_associate %[[ASSOCIATE]]#1, %[[ASSOCIATE]]#2 : !fir.ref<[[MAXTYPE]]>, i1
-!CHECK:    omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]])
+!CHECK:    hlfir.destroy %[[EXPRRESULT]] : !hlfir.expr<[[MAXTYPE]]>
+!CHECK:    fir.store %[[RESULT_VAL]] to %[[LHS_ARG]] : !fir.ref<[[MAXTYPE]]>
+!CHECK:    omp.yield(%[[LHS_ARG]] : !fir.ref<[[MAXTYPE]]>)
 !CHECK:  }
 
 !CHECK:  func.func @_QMmaxtype_modPinitme(%[[X_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %[[N_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {

>From 420922117891f9c301b42fb1dc1aee686f1f840f Mon Sep 17 00:00:00 2001
From: "Matt P. Dziubinski" <matt-p.dziubinski at hpe.com>
Date: Fri, 13 Mar 2026 15:20:14 -0500
Subject: [PATCH 2/3] WIP: whole-variable assignment

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    | 30 ++++++++++---
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 43 ++++++++++++++-----
 .../OpenMP/declare-reduction-intrinsic-op.f90 | 11 +++--
 .../omp-declare-reduction-derivedtype.f90     | 10 ++---
 4 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 85493bf45453e..4fb04929052b1 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -396,6 +396,7 @@ bool ClauseProcessor::processInitializer(
       for (const Object &object :
            std::get<StylizedInstance::Variables>(inst.t)) {
         mlir::Value addr;
+        std::string name = object.sym()->name().ToString();
         mlir::Type ompOrigType = ompOrig.getType();
         // Check for unsupported dynamic-length character reductions
         mlir::Type unwrappedType = fir::unwrapRefType(ompOrigType);
@@ -408,8 +409,16 @@ bool ClauseProcessor::processInitializer(
                  "OpenMP reduction allocation for dynamic length character");
           }
         }
-        // If ompOrig is already a reference, we can use it directly
-        if (fir::isa_ref_type(ompOrigType)) {
+        // For by-ref reductions, the init block has two arguments:
+        //   arg0 = mold/original, arg1 = private allocation.
+        // omp_priv must map to arg1 (the private copy), not arg0.
+        if (name == "omp_priv" && fir::isa_ref_type(ompOrigType)) {
+          mlir::Block *initBlock = builder.getInsertionBlock();
+          if (initBlock->getNumArguments() > 1)
+            addr = initBlock->getArgument(1);
+          else
+            addr = ompOrig;
+        } else if (fir::isa_ref_type(ompOrigType)) {
           addr = ompOrig;
         } else {
           addr = builder.createTemporary(loc, ompOrigType);
@@ -419,7 +428,6 @@ bool ClauseProcessor::processInitializer(
         fir::FortranVariableFlagsAttr attributes =
             Fortran::lower::translateSymbolAttributes(
                 builder.getContext(), *object.sym(), extraFlags);
-        std::string name = object.sym()->name().ToString();
         // Get length parameters for types that need them (e.g., characters).
         // Note: DeclareOp requires exactly one type parameter for non-boxed
         // characters, unlike EmboxOp which doesn't allow them for constant-len.
@@ -451,13 +459,23 @@ bool ClauseProcessor::processInitializer(
               [&](const auto &expr) -> mlir::Value {
                 mlir::Value exprResult = fir::getBase(convertExprToValue(
                     loc, converter, initExpr, symMap, stmtCtx));
-                // Conversion can either give a value or a refrence to a value,
-                // we need to return the reduction type, so an optional load may
-                // be generated.
                 if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
                         exprResult.getType()))
                   if (ompPrivVar.getType() == refType)
                     exprResult = fir::LoadOp::create(builder, loc, exprResult);
+
+                // For derived types in by-ref reductions, the init value
+                // (e.g. t(0)) must be stored into omp_priv explicitly.
+                // populateByRefInitAndCleanupRegions doesn't handle
+                // scalarInitValue for unboxed derived types, so we store
+                // here and return null to prevent a redundant store attempt.
+                if (ompPrivVar &&
+                    fir::isa_ref_type(ompPrivVar.getType()) &&
+                    fir::isa_derived(
+                        fir::unwrapRefType(ompPrivVar.getType()))) {
+                  fir::StoreOp::create(builder, loc, exprResult, ompPrivVar);
+                  return mlir::Value{};
+                }
                 return exprResult;
               }},
           initExpr.u);
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 818379d5a47b3..042fb37743327 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3839,6 +3839,10 @@ static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
     }
 
     lower::StatementContext stmtCtx;
+    mlir::Type eleTy = isByRef ? fir::unwrapRefType(lhs.getType())
+                               : lhs.getType();
+    bool isDerived = fir::isa_derived(eleTy);
+
     mlir::Value result = common::visit(
         common::visitors{
             [&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
@@ -3848,29 +3852,48 @@ static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
               return outVal;
             },
             [&](const auto &expr) -> mlir::Value {
+              if (isDerived && isByRef) {
+                mlir::Value exprResult = fir::getBase(convertExprToValue(
+                    loc, converter, evalExpr, symTable, stmtCtx));
+                // Check if this is a component-level expression (has a
+                // designator on omp_out) or a whole-struct expression
+                // (function call returning a struct).
+                mlir::Value designator;
+                for (auto *user : ompOutVar.getUsers()) {
+                  if (auto desig = mlir::dyn_cast<hlfir::DesignateOp>(user))
+                    designator = desig.getResult();
+                }
+                if (designator) {
+                  // Component-level: store into the designator.
+                  fir::StoreOp::create(builder, loc, exprResult, designator);
+                  return mlir::Value{};
+                }
+                // Whole-struct function result: return the value so the
+                // post-visit code can store it into lhs.
+                if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
+                        exprResult.getType()))
+                  exprResult = fir::LoadOp::create(builder, loc, exprResult);
+                return exprResult;
+              }
               mlir::Value exprResult = fir::getBase(convertExprToValue(
                   loc, converter, evalExpr, symTable, stmtCtx));
-              // Optional load may be generated if we get a reference to the
-              // reduction type.
               if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
                       exprResult.getType())) {
                 mlir::Type expectedType =
-                    isByRef ? fir::unwrapRefType(lhs.getType()) : lhs.getType();
+                    isByRef ? fir::unwrapRefType(lhs.getType())
+                            : lhs.getType();
                 if (expectedType == refType.getElementType())
-                  exprResult = fir::LoadOp::create(builder, loc, exprResult);
+                  exprResult =
+                      fir::LoadOp::create(builder, loc, exprResult);
               }
               return exprResult;
             }},
         evalExpr.u);
     stmtCtx.finalizeAndPop();
     if (isByRef) {
-      // For user-defined combiners the assignment expression (e.g.
-      // "omp_out%x = omp_out%x + omp_in%x") already wrote into omp_out
-      // as a side-effect. We only need to yield the lhs reference.
-      // Only store result back if its type actually matches the element type.
-      mlir::Type eleTy = fir::unwrapRefType(lhs.getType());
-      if (result.getType() == eleTy)
+      if (result) {
         fir::StoreOp::create(builder, loc, result, lhs);
+      }
       mlir::omp::YieldOp::create(builder, loc, lhs);
     } else {
       mlir::omp::YieldOp::create(builder, loc, result);
diff --git a/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90 b/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
index a2ec52902bba2..1ce41346d5468 100644
--- a/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
+++ b/flang/test/Lower/OpenMP/declare-reduction-intrinsic-op.f90
@@ -17,8 +17,10 @@ end program test
 ! CHECK:   omp.yield(%[[ALLOCA]] : !fir.ref<[[TY]]>)
 ! CHECK: } init {
 ! CHECK: ^bb0(%[[INIT_ARG0:.*]]: !fir.ref<[[TY]]>, %[[INIT_ARG1:.*]]: !fir.ref<[[TY]]>):
-! CHECK:   %{{.*}} = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"}
-! CHECK:   %{{.*}} = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_priv"}
+! CHECK:   %{{.*}} = fir.embox %[[INIT_ARG1]]
+! CHECK:   %{{.*}} = fir.embox %[[INIT_ARG0]]
+! CHECK:   %{{.*}}:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"}
+! CHECK:   %{{.*}}:2 = hlfir.declare %[[INIT_ARG1]] {uniq_name = "omp_priv"}
 ! CHECK:   omp.yield(%[[INIT_ARG1]] : !fir.ref<[[TY]]>)
 ! CHECK: } combiner {
 ! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<[[TY]]>, %[[ARG1:.*]]: !fir.ref<[[TY]]>):
@@ -28,6 +30,7 @@ end program test
 ! CHECK:   %[[OUT_X_VAL:.*]] = fir.load %[[OUT_X]] : !fir.ref<i32>
 ! CHECK:   %[[IN_X:.*]] = hlfir.designate %[[OMP_IN]]#0{"x"} : (!fir.ref<[[TY]]>) -> !fir.ref<i32>
 ! CHECK:   %[[IN_X_VAL:.*]] = fir.load %[[IN_X]] : !fir.ref<i32>
-! CHECK:   %{{.*}} = arith.addi %[[OUT_X_VAL]], %[[IN_X_VAL]] : i32
+! CHECK:   %[[ADD:.*]] = arith.addi %[[OUT_X_VAL]], %[[IN_X_VAL]] : i32
+! CHECK:   fir.store %[[ADD]] to %[[OUT_X]] : !fir.ref<i32>
 ! CHECK:   omp.yield(%[[ARG0]] : !fir.ref<[[TY]]>)
-! CHECK: }
\ No newline at end of file
+! CHECK: }
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
index 4ca735d6105f1..0ab32bf5a9be6 100644
--- a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
@@ -41,17 +41,17 @@ function func(x, n, init)
   end function func
 
 end module maxtype_mod
-!CHECK:  omp.declare_reduction @red_add_max : !fir.ref<{{.*}}> attributes {byref_element_type = {{.*}}} alloc {
+!CHECK:  omp.declare_reduction @red_add_max : !fir.ref<[[MAXTYPE:.*]]> {{.*}} alloc {
 !CHECK:  %[[ALLOCA:.*]] = fir.alloca [[MAXTYPE:.*]]
 !CHECK:  omp.yield(%[[ALLOCA]] : !fir.ref<[[MAXTYPE]]>)
 !CHECK:  } init {
 !CHECK:  ^bb0(%[[INIT_ARG0:.*]]: !fir.ref<[[MAXTYPE]]>, %[[INIT_ARG1:.*]]: !fir.ref<[[MAXTYPE]]>):
 !CHECK:    %{{.*}} = fir.embox %[[INIT_ARG1]]
 !CHECK:    %{{.*}} = fir.embox %[[INIT_ARG0]]
-!CHECK:    %[[OMP_ORIG:.*]]:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK:    %[[OMP_PRIV:.*]]:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK:    fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV]]#0, %[[OMP_ORIG]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
-!CHECK:    %{{.*}} = fir.load %[[OMP_PRIV]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK:    %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[INIT_ARG0]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[INIT_ARG1]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
+!CHECK:    %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
 !CHECK:    omp.yield(%[[INIT_ARG1]] : !fir.ref<[[MAXTYPE]]>)
 !CHECK:  } combiner {
 !CHECK:  ^bb0(%[[LHS_ARG:.*]]: !fir.ref<[[MAXTYPE]]>, %[[RHS_ARG:.*]]: !fir.ref<[[MAXTYPE]]>):

>From cf482e21d0058b4c0013935ccaa2300f3bf6af0c Mon Sep 17 00:00:00 2001
From: "Matt P. Dziubinski" <matt-p.dziubinski at hpe.com>
Date: Fri, 13 Mar 2026 15:27:34 -0500
Subject: [PATCH 3/3] Whole & component assignment

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 4fb04929052b1..da902efdad679 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -465,7 +465,7 @@ bool ClauseProcessor::processInitializer(
                     exprResult = fir::LoadOp::create(builder, loc, exprResult);
 
                 // For derived types in by-ref reductions, the init value
-                // (e.g. t(0)) must be stored into omp_priv explicitly.
+                // must be stored into omp_priv explicitly.
                 // populateByRefInitAndCleanupRegions doesn't handle
                 // scalarInitValue for unboxed derived types, so we store
                 // here and return null to prevent a redundant store attempt.
@@ -473,7 +473,14 @@ bool ClauseProcessor::processInitializer(
                     fir::isa_ref_type(ompPrivVar.getType()) &&
                     fir::isa_derived(
                         fir::unwrapRefType(ompPrivVar.getType()))) {
-                  fir::StoreOp::create(builder, loc, exprResult, ompPrivVar);
+                  // Only store if the expression result type matches the
+                  // whole derived type. For component-level initializers
+                  // (e.g. omp_priv%i=0), the assignment was already
+                  // performed as a side effect during expression lowering.
+                  mlir::Type privEleTy =
+                      fir::unwrapRefType(ompPrivVar.getType());
+                  if (exprResult.getType() == privEleTy)
+                    fir::StoreOp::create(builder, loc, exprResult, ompPrivVar);
                   return mlir::Value{};
                 }
                 return exprResult;


