[Mlir-commits] [flang] [mlir] [OpenMP][flang] Lowering of OpenMP custom reductions to MLIR (PR #168417)

Jan Leyonberg llvmlistbot at llvm.org
Mon Nov 17 10:14:22 PST 2025


https://github.com/jsjodin created https://github.com/llvm/llvm-project/pull/168417

This patch adds support for lowering custom reductions to MLIR. It also extends the MarkDeclareTarget pass so that it automatically marks functions as "declare target" by traversing the initializers and combiners of custom reductions.

>From f8a24cbfad7db15a6197e48e8f1d2343a1d02869 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Thu, 18 Sep 2025 10:50:04 -0400
Subject: [PATCH 1/2] [OpenMP][flang] Lowering of OpenMP custom reductions to
 MLIR

This patch adds support for lowering custom reductions to MLIR.
---
 .../flang/Lower/Support/ReductionProcessor.h  |  18 +++
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |  60 ++++++++
 flang/lib/Lower/OpenMP/ClauseProcessor.h      |   5 +
 flang/lib/Lower/OpenMP/Clauses.cpp            |  17 ++-
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 133 ++++++++++++++++-
 .../lib/Lower/Support/ReductionProcessor.cpp  |  90 ++++++++----
 .../Optimizer/OpenMP/MarkDeclareTarget.cpp    | 139 +++++++++++++-----
 .../Todo/omp-declare-reduction-initsub.f90    |  28 ----
 .../OpenMP/Todo/omp-declare-reduction.f90     |  10 --
 ...are-target-deferred-marking-reductions.f90 |  37 +++++
 .../omp-declare-reduction-derivedtype.f90     | 112 ++++++++++++++
 .../OpenMP/omp-declare-reduction-initsub.f90  |  59 ++++++++
 .../Lower/OpenMP/omp-declare-reduction.f90    |  33 +++++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  16 +-
 14 files changed, 650 insertions(+), 107 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90
 delete mode 100644 flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90
 create mode 100644 flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
 create mode 100644 flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
 create mode 100644 flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
 create mode 100644 flang/test/Lower/OpenMP/omp-declare-reduction.f90

diff --git a/flang/include/flang/Lower/Support/ReductionProcessor.h b/flang/include/flang/Lower/Support/ReductionProcessor.h
index 66f26b3b55630..bd0447360f089 100644
--- a/flang/include/flang/Lower/Support/ReductionProcessor.h
+++ b/flang/include/flang/Lower/Support/ReductionProcessor.h
@@ -40,6 +40,13 @@ namespace omp {
 
 class ReductionProcessor {
 public:
+  using GenInitValueCBTy =
+      std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
+                                mlir::Type type, mlir::Value ompOrig)>;
+  using GenCombinerCBTy = std::function<void(
+      fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
+      mlir::Value op1, mlir::Value op2, bool isByRef)>;
+
   // TODO: Move this enumeration to the OpenMP dialect
   enum ReductionIdentifier {
     ID,
@@ -58,6 +65,9 @@ class ReductionProcessor {
     IEOR
   };
 
+  static bool doReductionByRef(mlir::Type reductionType);
+  static bool doReductionByRef(mlir::Value reductionVar);
+
   static ReductionIdentifier
   getReductionType(const omp::clause::ProcedureDesignator &pd);
 
@@ -109,6 +119,14 @@ class ReductionProcessor {
                                           ReductionIdentifier redId,
                                           mlir::Type type, mlir::Value op1,
                                           mlir::Value op2);
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The init and combiner regions are generated by the callback
+  /// functions genCombinerCB and genInitValueCB.
+  template <typename DeclareRedType>
+  static DeclareRedType createDeclareReductionHelper(
+      AbstractConverter &converter, llvm::StringRef reductionOpName,
+      mlir::Type type, mlir::Location loc, bool isByRef,
+      GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
 
   /// Creates an OpenMP reduction declaration and inserts it into the provided
   /// symbol table. The declaration has a constant initializer with the neutral
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index e018a2d937435..fadfb29b07a28 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -13,6 +13,7 @@
 #include "ClauseProcessor.h"
 #include "Utils.h"
 
+#include "flang/Lower/ConvertCall.h"
 #include "flang/Lower/ConvertExprToHLFIR.h"
 #include "flang/Lower/OpenMP/Clauses.h"
 #include "flang/Lower/PFTBuilder.h"
@@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
   return false;
 }
 
+/// Sets \p genInitValueCB to emit the initializer of a user-defined reduction
+/// from \p inp, which is either an expression or a subroutine call.
+bool ClauseProcessor::processInitializer(
+    lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
+    ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
+  auto *clause = findUniqueClause<omp::clause::Initializer>();
+  if (!clause)
+    return false;
+  genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
+                               mlir::Type type, mlir::Value ompOrig) {
+    lower::SymMapScope scope(symMap);
+    const parser::OmpStylizedInstance &styleInstance = inp.v.v.v.front();
+    const auto &declList =
+        std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
+    // Bind each declared name to a new temporary initialized from ompOrig.
+    mlir::Value ompPrivVar;
+    for (const parser::OmpStylizedDeclaration &decl : declList) {
+      auto &name = std::get<parser::ObjectName>(decl.var.t);
+      assert(name.symbol && "Name does not have a symbol");
+      mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
+      fir::StoreOp::create(builder, loc, ompOrig, addr);
+      fir::FortranVariableFlagsEnum extraFlags = {};
+      fir::FortranVariableFlagsAttr attributes =
+          Fortran::lower::translateSymbolAttributes(builder.getContext(),
+                                                    *name.symbol, extraFlags);
+      auto declareOp = hlfir::DeclareOp::create(
+          builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
+          0, attributes);
+      if (name.ToString() == "omp_priv")
+        ompPrivVar = declareOp.getResult(0);
+      symMap.addVariableDefinition(*name.symbol, declareOp);
+    }
+    // Lower the initializer: a subroutine call or an omp_priv expression.
+    lower::StatementContext stmtCtx;
+    mlir::Value result = common::visit(
+        common::visitors{
+            [&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
+              convertCallToHLFIR(loc, converter, procRef, std::nullopt,
+                                 symMap, stmtCtx);
+              return fir::LoadOp::create(builder, loc, ompPrivVar);
+            },
+            [&](const auto &expr) -> mlir::Value {
+              mlir::Value exprResult = fir::getBase(convertExprToValue(
+                  loc, converter, clause->v, symMap, stmtCtx));
+              // Conversion may produce a value or a reference to one; load
+              // when it matches omp_priv's reference type.
+              if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
+                      exprResult.getType()))
+                if (ompPrivVar.getType() == refType)
+                  exprResult = fir::LoadOp::create(builder, loc, exprResult);
+              return exprResult;
+            }},
+        clause->v.u);
+    stmtCtx.finalizeAndPop();
+    return result;
+  };
+  return true;
+}
+
 bool ClauseProcessor::processMergeable(
     mlir::omp::MergeableClauseOps &result) const {
   return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index d524b4ddc8ac4..85ca9ecf9d98a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -12,12 +12,14 @@
 #ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
+
 #include "ClauseFinder.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/OpenMP/Clauses.h"
+#include "flang/Lower/Support/ReductionProcessor.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/parse-tree.h"
@@ -88,6 +90,9 @@ class ClauseProcessor {
   bool processHint(mlir::omp::HintClauseOps &result) const;
   bool processInclusive(mlir::Location currentLocation,
                         mlir::omp::InclusiveClauseOps &result) const;
+  bool processInitializer(
+      lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
+      ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
   bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
   bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
   bool processNowait(mlir::omp::NowaitClauseOps &result) const;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b1a3c3d3c5439..cf8d9a7ee6596 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,
 
 Initializer make(const parser::OmpClause::Initializer &inp,
                  semantics::SemanticsContext &semaCtx) {
-  llvm_unreachable("Empty: initializer");
+  // Only the first stylized instance is used (an assignment or a call).
+  const parser::OmpStylizedInstance &styleInstance = inp.v.v.v.front();
+  const parser::OmpStylizedInstance::Instance &instance =
+      std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
+  if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
+    auto &expr = std::get<parser::Expr>(as->t);
+    return Initializer{makeExpr(expr, semaCtx)};
+  }
+  if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
+    if (call->typedCall) {
+      const auto &procRef = *call->typedCall;
+      return Initializer{semantics::SomeExpr{procRef}};
+    }
+  }
+  // Also reached for an untyped call: never fall off this non-void function.
+  llvm_unreachable("Unexpected initializer");
 }
 
 InReduction make(const parser::OmpClause::InReduction &inp,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f822fe3c8dd71..c4b174db8ac22 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -18,12 +18,15 @@
 #include "Decomposer.h"
 #include "Utils.h"
 #include "flang/Common/idioms.h"
+#include "flang/Evaluate/type.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
+#include "flang/Lower/ConvertExprToHLFIR.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/OpenMP/Clauses.h"
 #include "flang/Lower/StatementContext.h"
+#include "flang/Lower/Support/ReductionProcessor.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   // TODO: Add private syms and vars.
   args.reduction.syms = reductionSyms;
   args.reduction.vars = clauseOps.reductionVars;
-
   return genOpWithBody<mlir::omp::TeamsOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_teams)
@@ -3563,12 +3565,137 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
     TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
 }
 
+/// Sets \p genCombinerCB to emit the combiner of a user-defined reduction.
+/// Only assignment combiners are lowered; other forms hit a TODO.
+static bool
+processReductionCombiner(lower::AbstractConverter &converter,
+                         lower::SymMap &symTable,
+                         semantics::SemanticsContext &semaCtx,
+                         const parser::OmpReductionSpecifier &specifier,
+                         ReductionProcessor::GenCombinerCBTy &genCombinerCB) {
+  const auto &combinerExpression =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value();
+  const parser::OmpStylizedInstance &combinerInstance =
+      combinerExpression.v.front();
+  const parser::OmpStylizedInstance::Instance &instance =
+      std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
+  const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
+  if (!as)
+    TODO(converter.getCurrentLocation(),
+         "declare reduction with a non-assignment combiner");
+  const auto &expr = std::get<parser::Expr>(as->t);
+
+  genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
+                      mlir::Type type, mlir::Value lhs, mlir::Value rhs,
+                      bool isByRef) {
+    const auto &evalExpr = makeExpr(expr, semaCtx);
+    lower::SymMapScope scope(symTable);
+    const auto &declList =
+        std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+    // Bind omp_in to rhs and every other declared name to lhs.
+    for (const parser::OmpStylizedDeclaration &decl : declList) {
+      auto &name = std::get<parser::ObjectName>(decl.var.t);
+      assert(name.symbol && "Reduction object name does not have a symbol");
+      mlir::Value val = name.ToString() == "omp_in" ? rhs : lhs;
+      mlir::Value addr = val;
+      // hlfir.declare needs memory, so by-value operands get a temporary.
+      if (!fir::conformsWithPassByRef(val.getType())) {
+        addr = builder.createTemporary(loc, val.getType());
+        fir::StoreOp::create(builder, loc, val, addr);
+      }
+      fir::FortranVariableFlagsEnum extraFlags = {};
+      fir::FortranVariableFlagsAttr attributes =
+          Fortran::lower::translateSymbolAttributes(builder.getContext(),
+                                                    *name.symbol, extraFlags);
+      auto declareOp = hlfir::DeclareOp::create(
+          builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
+          0, attributes);
+      symTable.addVariableDefinition(*name.symbol, declareOp);
+    }
+
+    lower::StatementContext stmtCtx;
+    mlir::Value result = fir::getBase(
+        convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
+    // The result may be a reference to the reduction type; load it. lhs is
+    // itself a reference when isByRef, so compare against its element type.
+    if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
+      if (fir::unwrapRefType(lhs.getType()) == refType.getElementType())
+        result = fir::LoadOp::create(builder, loc, result);
+    stmtCtx.finalizeAndPop();
+    if (isByRef) {
+      fir::StoreOp::create(builder, loc, result, lhs);
+      mlir::omp::YieldOp::create(builder, loc, lhs);
+    } else {
+      mlir::omp::YieldOp::create(builder, loc, result);
+    }
+  };
+  return true;
+}
+
+// Derive the reduction type from a semantics symbol rather than the DeclSpec:
+// this avoids distinguishing derived from intrinsic types, and semantics is
+// guaranteed to create these symbols.
+static mlir::Type
+getReductionType(lower::AbstractConverter &converter,
+                 const parser::OmpReductionSpecifier &specifier) {
+  const parser::OmpCombinerExpression &combiner =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value();
+
+  // The symbol of the first stylized declaration of the first combiner
+  // instance determines the reduction type.
+  const auto &declList = std::get<std::list<parser::OmpStylizedDeclaration>>(
+      combiner.v.front().t);
+  const auto &name = std::get<parser::ObjectName>(declList.front().var.t);
+  const semantics::SymbolRef symbol{*name.symbol};
+  mlir::Type reductionType = converter.genType(symbol);
+  return reductionType;
+}
+
 static void genOMP(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
-  if (!semaCtx.langOptions().OpenMPSimd)
-    TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
+  // Nothing is emitted when OpenMPSimd is set.
+  if (semaCtx.langOptions().OpenMPSimd)
+    return;
+
+  mlir::Location loc = converter.getCurrentLocation();
+  const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
+  const auto &specifier =
+      std::get<parser::OmpReductionSpecifier>(args.v.front().u);
+  if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
+    TODO(loc, "multiple types in declare reduction is not yet supported");
+  if (!std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t))
+    TODO(loc, "declare reduction without a combiner is not yet supported");
+  const auto *designator = std::get_if<parser::ProcedureDesignator>(
+      &std::get<parser::OmpReductionIdentifier>(specifier.t).u);
+  if (!designator)
+    TODO(loc, "declare reduction with an operator is not yet supported");
+  const parser::OmpClauseList &initializer =
+      declareReductionConstruct.v.Clauses();
+  if (initializer.v.empty())
+    TODO(loc, "declare reduction without an initializer clause is not yet "
+              "supported");
+
+  // Build the combiner and initializer callbacks, then emit the declaration.
+  mlir::Type reductionType = getReductionType(converter, specifier);
+  ReductionProcessor::GenCombinerCBTy genCombinerCB;
+  processReductionCombiner(converter, symTable, semaCtx, specifier,
+                           genCombinerCB);
+  List<Clause> clauses = makeClauses(initializer, semaCtx);
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  ReductionProcessor::GenInitValueCBTy genInitValueCB;
+  const auto &iclause =
+      std::get<parser::OmpClause::Initializer>(initializer.v.front().u);
+  cp.processInitializer(symTable, iclause, genInitValueCB);
+  const auto &reductionName = std::get<parser::Name>(designator->u);
+  bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+  ReductionProcessor::createDeclareReductionHelper<
+      mlir::omp::DeclareReductionOp>(converter, reductionName.ToString(),
+                                     reductionType, loc, isByRef,
+                                     genCombinerCB, genInitValueCB);
 }
 
 static void
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index 605a5b6b20b94..283e5ea73c319 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -462,7 +462,7 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                         bool isByRef) {
   ty = fir::unwrapRefType(ty);
 
-  if (fir::isa_trivial(ty)) {
+  if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
 
@@ -501,7 +501,7 @@ static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
 template <typename OpType>
 static void createReductionAllocAndInitRegions(
     AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
-    const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
+    ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type,
     bool isByRef) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };
@@ -523,9 +523,8 @@ static void createReductionAllocAndInitRegions(
 
   mlir::Type ty = fir::unwrapRefType(type);
   builder.setInsertionPointToEnd(initBlock);
-  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
-      loc, unwrapSeqOrBoxedType(ty), redId, builder);
-
+  mlir::Value initValue =
+      genInitValueCB(builder, loc, ty, initBlock->getArgument(0));
   if (isByRef) {
     populateByRefInitAndCleanupRegions(
         converter, loc, type, initValue, initBlock,
@@ -536,7 +535,7 @@ static void createReductionAllocAndInitRegions(
         /*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
   }
 
-  if (fir::isa_trivial(ty)) {
+  if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
     if (isByRef) {
       // alloc region
       builder.setInsertionPointToEnd(allocBlock);
@@ -556,18 +555,18 @@ static void createReductionAllocAndInitRegions(
   yield(boxAlloca);
 }
 
-template <typename OpType>
-OpType ReductionProcessor::createDeclareReduction(
+template <typename DeclareRedType>
+DeclareRedType ReductionProcessor::createDeclareReductionHelper(
     AbstractConverter &converter, llvm::StringRef reductionOpName,
-    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
-    bool isByRef) {
+    mlir::Type type, mlir::Location loc, bool isByRef,
+    GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   mlir::OpBuilder::InsertionGuard guard(builder);
   mlir::ModuleOp module = builder.getModule();
 
   assert(!reductionOpName.empty());
 
-  auto decl = module.lookupSymbol<OpType>(reductionOpName);
+  auto decl = module.lookupSymbol<DeclareRedType>(reductionOpName);
   if (decl)
     return decl;
 
@@ -576,23 +575,54 @@ OpType ReductionProcessor::createDeclareReduction(
   if (!isByRef)
     type = valTy;
 
-  decl = OpType::create(modBuilder, loc, reductionOpName, type);
-  createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
+  decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type);
+  createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type,
                                      isByRef);
-
   builder.createBlock(&decl.getReductionRegion(),
                       decl.getReductionRegion().end(), {type, type},
                       {loc, loc});
-
   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-  genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
-
+  genCombinerCB(builder, loc, type, op1, op2, isByRef);
   return decl;
 }
 
-static bool doReductionByRef(mlir::Value reductionVar) {
+template <typename OpType>
+OpType ReductionProcessor::createDeclareReduction(
+    AbstractConverter &converter, llvm::StringRef reductionOpName,
+    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+    bool isByRef) {
+  // Intrinsic reductions initialize with the neutral element of redId and
+  // combine with the operation redId names.
+  auto genInitValueCB = [&](fir::FirOpBuilder &builder, mlir::Location initLoc,
+                            mlir::Type initType, mlir::Value /*ompOrig*/) {
+    return ReductionProcessor::getReductionInitValue(
+        initLoc, unwrapSeqOrBoxedType(fir::unwrapRefType(initType)), redId,
+        builder);
+  };
+  auto genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location combLoc,
+                           mlir::Type combType, mlir::Value op1,
+                           mlir::Value op2, bool byRef) {
+    genCombiner<OpType>(builder, combLoc, redId, combType, op1, op2, byRef);
+  };
+  return createDeclareReductionHelper<OpType>(converter, reductionOpName, type,
+                                              loc, isByRef, genCombinerCB,
+                                              genInitValueCB);
+}
+
+bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
+  // forceByrefReduction overrides the type-based choice below.
+  if (forceByrefReduction)
+    return true;
+
+  // Trivial and derived types are reduced by value; every other type is
+  // reduced by reference.
+  mlir::Type valueType = fir::unwrapRefType(reductionType);
+  return !fir::isa_trivial(valueType) && !fir::isa_derived(valueType);
+}
+
+bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) {
   if (forceByrefReduction)
     return true;
 
@@ -600,10 +630,7 @@ static bool doReductionByRef(mlir::Value reductionVar) {
           mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
     reductionVar = declare.getMemref();
 
-  if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
-    return true;
-
-  return false;
+  return doReductionByRef(reductionVar.getType());
 }
 
 template <typename OpType, typename RedOperatorListTy>
@@ -614,6 +641,8 @@ bool ReductionProcessor::processReductionArguments(
     llvm::SmallVectorImpl<bool> &reduceVarByRef,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
     const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
   if constexpr (std::is_same_v<RedOperatorListTy,
                                omp::clause::ReductionOperatorList>) {
     // For OpenMP reduction clauses, check if the reduction operator is
@@ -627,7 +656,13 @@ bool ReductionProcessor::processReductionArguments(
               std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
         if (!ReductionProcessor::supportedIntrinsicProcReduction(
                 *reductionIntrinsic)) {
-          return false;
+          // If not an intrinsic is has to be a custom reduction op, and should
+          // be available in the module.
+          semantics::Symbol *sym = reductionIntrinsic->v.sym();
+          mlir::ModuleOp module = builder.getModule();
+          auto decl = module.lookupSymbol<OpType>(getRealName(sym).ToString());
+          if (!decl)
+            return false;
         }
       } else {
         return false;
@@ -637,7 +672,6 @@ bool ReductionProcessor::processReductionArguments(
 
   // Reduction variable processing common to both intrinsic operators and
   // procedure designators
-  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   mlir::OpBuilder::InsertPoint dcIP;
   constexpr bool isDoConcurrent =
       std::is_same_v<OpType, fir::DeclareReductionOp>;
@@ -741,7 +775,13 @@ bool ReductionProcessor::processReductionArguments(
                          &redOperator.u)) {
         if (!ReductionProcessor::supportedIntrinsicProcReduction(
                 *reductionIntrinsic)) {
-          TODO(currentLocation, "Unsupported intrinsic proc reduction");
+          // Custom reductions we can just add to the symbols without
+          // generating the declare reduction op.
+          semantics::Symbol *sym = reductionIntrinsic->v.sym();
+          reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+              builder.getContext(), sym->name().ToString()));
+          ++idx;
+          continue;
         }
         redId = getReductionType(*reductionIntrinsic);
         reductionName =
diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
index 0b0e6bd9ecf34..1bd1cd6b05fff 100644
--- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
+++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace flangomp {
 #define GEN_PASS_DEF_MARKDECLARETARGETPASS
@@ -31,9 +32,95 @@ namespace {
 class MarkDeclareTargetPass
     : public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
 
-  void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
-                       mlir::omp::DeclareTargetCaptureClause parentCapClause,
-                       bool parentAutomap, mlir::Operation *currOp,
+  struct ParentInfo {
+    mlir::omp::DeclareTargetDeviceType devTy;
+    mlir::omp::DeclareTargetCaptureClause capClause;
+    bool automap;
+  };
+
+  /// Marks the func.func named by \p symRef as declare target (merging device
+  /// types if already marked) and then recurses into its body.
+  void
+  processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo,
+                   const llvm::SmallPtrSet<mlir::Operation *, 16> &visited) {
+    auto currFOp = getOperation().lookupSymbol<mlir::func::FuncOp>(symRef);
+    if (!currFOp)
+      return;
+    auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
+        currFOp.getOperation());
+    if (!current.isDeclareTarget()) {
+      current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
+                               parentInfo.automap);
+    } else {
+      auto currentDt = current.getDeclareTargetDeviceType();
+      // Found the same function twice, with different device_types,
+      // mark as Any as it belongs to both
+      if (currentDt != parentInfo.devTy &&
+          currentDt != mlir::omp::DeclareTargetDeviceType::any)
+        current.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
+                                 current.getDeclareTargetCaptureClause(),
+                                 current.getDeclareTargetAutomap());
+    }
+
+    markNestedFuncs(parentInfo, currFOp, visited);
+  }
+
+  /// Recurses into each omp.declare_reduction named in \p symRefs so that
+  /// functions called from its regions are marked as well.
+  void processReductionRefs(
+      std::optional<mlir::ArrayAttr> symRefs, ParentInfo parentInfo,
+      const llvm::SmallPtrSet<mlir::Operation *, 16> &visited) {
+    if (!symRefs)
+      return;
+
+    for (auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) {
+      auto declareReductionOp =
+          getOperation().lookupSymbol<mlir::omp::DeclareReductionOp>(symRef);
+      if (declareReductionOp)
+        markNestedFuncs(parentInfo, declareReductionOp, visited);
+    }
+  }
+
+  /// Follows the declare_reduction symbols of \p op's reduction-like clauses.
+  void processReductionClauses(
+      mlir::Operation *op, ParentInfo parentInfo,
+      const llvm::SmallPtrSet<mlir::Operation *, 16> &visited) {
+    llvm::TypeSwitch<mlir::Operation &>(*op)
+        .Case([&](mlir::omp::LoopOp op) {
+          processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::ParallelOp op) {
+          processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::SectionsOp op) {
+          processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::SimdOp op) {
+          processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::TargetOp op) {
+          processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::TaskgroupOp op) {
+          processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::TaskloopOp op) {
+          processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+          processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::TaskOp op) {
+          processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::TeamsOp op) {
+          processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+        })
+        .Case([&](mlir::omp::WsloopOp op) {
+          processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+        })
+        .Default([](mlir::Operation &) {});
+  }
+
+  void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp,
                        llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
     if (visited.contains(currOp))
       return;
@@ -43,33 +130,10 @@ class MarkDeclareTargetPass
       if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
         if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
                 callOp.getCallableForCallee())) {
-          if (auto currFOp =
-                  getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
-            auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
-                currFOp.getOperation());
-
-            if (current.isDeclareTarget()) {
-              auto currentDt = current.getDeclareTargetDeviceType();
-
-              // Found the same function twice, with different device_types,
-              // mark as Any as it belongs to both
-              if (currentDt != parentDevTy &&
-                  currentDt != mlir::omp::DeclareTargetDeviceType::any) {
-                current.setDeclareTarget(
-                    mlir::omp::DeclareTargetDeviceType::any,
-                    current.getDeclareTargetCaptureClause(),
-                    current.getDeclareTargetAutomap());
-              }
-            } else {
-              current.setDeclareTarget(parentDevTy, parentCapClause,
-                                       parentAutomap);
-            }
-
-            markNestedFuncs(parentDevTy, parentCapClause, parentAutomap,
-                            currFOp, visited);
-          }
+          processSymbolRef(symRef, parentInfo, visited);
         }
       }
+      processReductionClauses(op, parentInfo, visited);
     });
   }
 
@@ -82,10 +146,10 @@ class MarkDeclareTargetPass
           functionOp.getOperation());
       if (declareTargetOp.isDeclareTarget()) {
         llvm::SmallPtrSet<mlir::Operation *, 16> visited;
-        markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
-                        declareTargetOp.getDeclareTargetCaptureClause(),
-                        declareTargetOp.getDeclareTargetAutomap(), functionOp,
-                        visited);
+        ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
+                              declareTargetOp.getDeclareTargetCaptureClause(),
+                              declareTargetOp.getDeclareTargetAutomap()};
+        markNestedFuncs(parentInfo, functionOp, visited);
       }
     }
 
@@ -96,12 +160,13 @@ class MarkDeclareTargetPass
     // the contents of the device clause
     getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
       llvm::SmallPtrSet<mlir::Operation *, 16> visited;
-      markNestedFuncs(
-          /*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
-          /*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to,
-          /*parentAutomap=*/false, tarOp, visited);
+      ParentInfo parentInfo = {
+          /*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
+          /*capClause=*/mlir::omp::DeclareTargetCaptureClause::to,
+          /*automap=*/false,
+      };
+      markNestedFuncs(parentInfo, tarOp, visited);
     });
   }
 };
-
 } // namespace
diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90
deleted file mode 100644
index 30630465490b2..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90
+++ /dev/null
@@ -1,28 +0,0 @@
-! This test checks lowering of OpenMP declare reduction Directive, with initialization
-! via a subroutine. This functionality is currently not implemented.
-
-! RUN: not flang -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
-
-!CHECK: not yet implemented: OpenMPDeclareReductionConstruct
-subroutine initme(x,n)
-  integer x,n
-  x=n
-end subroutine initme
-
-function func(x, n, init)
-  integer func
-  integer x(n)
-  integer res
-  interface
-     subroutine initme(x,n)
-       integer x,n
-     end subroutine initme
-  end interface
-!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,0))
-  res=init
-!$omp simd reduction(red_add:res)
-  do i=1,n
-     res=res+x(i)
-  enddo
-  func=res
-end function func
diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90
deleted file mode 100644
index db50c9ac8ee9d..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90
+++ /dev/null
@@ -1,10 +0,0 @@
-! This test checks lowering of OpenMP declare reduction Directive.
-
-! RUN: not flang -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
-
-subroutine declare_red()
-  integer :: my_var
-  !CHECK: not yet implemented: OpenMPDeclareReductionConstruct
-  !$omp declare reduction (my_red : integer : omp_out = omp_in) initializer (omp_priv = 0)
-  my_var = 0
-end subroutine declare_red
diff --git a/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90 b/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
new file mode 100644
index 0000000000000..a3c38d7ba0a25
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
@@ -0,0 +1,37 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s --check-prefixes ALL
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL
+
+program main
+    use, intrinsic :: iso_c_binding
+    implicit none
+    interface
+    subroutine myinit(priv, orig) bind(c,name="myinit")
+        use, intrinsic :: iso_c_binding
+        implicit none
+        integer::priv, orig
+    end subroutine myinit
+
+    function mycombine(lhs, rhs) bind(c,name="mycombine")
+        use, intrinsic :: iso_c_binding
+        implicit none
+        integer::lhs, rhs, mycombine
+    end function mycombine
+    end interface
+    !$omp declare reduction(myreduction:integer:omp_out = mycombine(omp_out, omp_in)) initializer(myinit(omp_priv, omp_orig))
+    ! myinit/mycombine are only reachable via the reduction's initializer/combiner.
+    integer :: i, s, a(10)
+    !$omp target
+    s = 0
+    !$omp do reduction(myreduction:s)
+    do i = 1, 10
+       s = s + a(i)
+    enddo
+    !$omp end do
+    !$omp end target
+end program main
+
+!ALL-LABEL: func.func {{.*}} @myinit(!fir.ref<i32>, !fir.ref<i32>)
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
+!ALL-LABEL: func.func {{.*}} @mycombine(!fir.ref<i32>, !fir.ref<i32>)
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
+
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
new file mode 100644
index 0000000000000..d544c0db488f0
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
@@ -0,0 +1,112 @@
+! This test checks lowering of an OpenMP declare reduction directive on a derived
+! type, with a function combiner and initialization via a subroutine.
+
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+module maxtype_mod
+  implicit none
+  ! Reduction value: a running sum and a running maximum.
+  type maxtype
+     integer::sumval
+     integer::maxval
+  end type maxtype
+
+contains
+  ! Initializer for red_add_max: zeroes both components of x (n is unused).
+  subroutine initme(x,n)
+    type(maxtype) :: x,n
+    x%sumval=0
+    x%maxval=0
+  end subroutine initme
+  ! Combiner for red_add_max: adds the sumval components, keeps the larger maxval.
+  function mycombine(lhs, rhs)
+    type(maxtype) :: lhs, rhs
+    type(maxtype) :: mycombine
+    mycombine%sumval = lhs%sumval + rhs%sumval
+    mycombine%maxval = max(lhs%maxval, rhs%maxval)
+  end function mycombine
+  ! Folds x(1:n) into a copy of init with the red_add_max reduction in an omp simd loop.
+  function func(x, n, init)
+    type(maxtype) :: func
+    integer :: n, i
+    type(maxtype) :: x(n)
+    type(maxtype) :: init
+    type(maxtype) :: res
+!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig))
+    res=init
+!$omp simd reduction(red_add_max:res)
+    do i=1,n
+       res=mycombine(res,x(i))
+    enddo
+    func=res
+  end function func
+
+end module maxtype_mod
+!CHECK:  omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init {
+!CHECK:  ^bb0(%[[ARGI_0:.*]]: [[MAXTYPE]]):
+!CHECK:    %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
+!CHECK:    %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
+!CHECK:    fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
+!CHECK:    %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
+!CHECK:    %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
+!CHECK:    %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK:    omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]])
+!CHECK:  } combiner {
+!CHECK:  ^bb0(%[[ARGC_0:.*]]: [[MAXTYPE]], %[[ARGC_1:.*]]: [[MAXTYPE]]):
+!CHECK:    %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"}
+!CHECK:    %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]]
+!CHECK:    %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]]
+!CHECK:    fir.store %[[ARGC_1]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
+!CHECK:    %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    fir.store %[[ARGC_0]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
+!CHECK:    %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
+!CHECK:    fir.save_result %[[COMBINE_RESULT]] to %[[RESULT]] : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]>
+!CHECK:    %[[TMPRESULT:.*]]:2 = hlfir.declare %[[RESULT]] {uniq_name = ".tmp.func_result"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[FALSE:.*]] = arith.constant false
+!CHECK:    %[[EXPRRESULT:.*]] = hlfir.as_expr %[[TMPRESULT]]#0 move %[[FALSE]] : (!fir.ref<[[MAXTYPE]]>, i1) -> !hlfir.expr<[[MAXTYPE]]>
+!CHECK:    %[[ASSOCIATE:.*]]:3 = hlfir.associate %[[EXPRRESULT]] {adapt.valuebyref} : (!hlfir.expr<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>, i1)
+!CHECK:    %[[RESULT_VAL:.*]] = fir.load %[[ASSOCIATE]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK:    hlfir.end_associate %[[ASSOCIATE]]#1, %[[ASSOCIATE]]#2 : !fir.ref<[[MAXTYPE]]>, i1
+!CHECK:    omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]])
+!CHECK:  }
+
+!CHECK:  func.func @_QMmaxtype_modPinitme(%[[X:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %[[N:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {
+!CHECK:    %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
+!CHECK:    %[[N_DECL:.*]]:2 = hlfir.declare %[[N]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFinitmeEn"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFinitmeEx"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[ZERO_0:.*]] = arith.constant 0 : i32
+!CHECK:    %[[X_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"sumval"}   : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK:    hlfir.assign %[[ZERO_0]] to %[[X_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
+!CHECK:    %[[ZERO_1:.*]] = arith.constant 0 : i32
+!CHECK:    %[[X_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"maxval"}   : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK:    hlfir.assign %[[ZERO_1]] to %[[X_DESIGNATE_MAXVAL]] : i32, !fir.ref<i32>
+!CHECK:    return
+!CHECK:  }
+
+
+!CHECK:  func.func @_QMmaxtype_modPmycombine(%[[LHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "lhs"}, %[[RHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "rhs"}) -> [[MAXTYPE]] {
+!CHECK:    %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
+!CHECK:    %[[LHS_DECL:.*]]:2 = hlfir.declare %[[LHS]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFmycombineElhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[RESULT_ALLOC:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = "mycombine", uniq_name = "_QMmaxtype_modFmycombineEmycombine"}
+!CHECK:    %[[RESULT_DECL:.*]]:2 = hlfir.declare %[[RESULT_ALLOC]] {uniq_name = "_QMmaxtype_modFmycombineEmycombine"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[RHS_DECL:.*]]:2 = hlfir.declare %[[RHS]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFmycombineErhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK:    %[[LHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"sumval"}   : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK:    %[[LHS_SUMVAL:.*]] = fir.load %[[LHS_DESIGNATE_SUMVAL]] : !fir.ref<i32>
+!CHECK:    %[[RHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"sumval"}   : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK:    %[[RHS_SUMVAL:.*]] = fir.load %[[RHS_DESIGNATE_SUMVAL]] : !fir.ref<i32>
+!CHECK:    %[[SUM:.*]] = arith.addi %[[LHS_SUMVAL]], %[[RHS_SUMVAL]] : i32
+!CHECK:    %[[RESULT_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"sumval"}   : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK:    hlfir.assign %[[SUM]] to %[[RESULT_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
+!CHECK:    %[[LHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"maxval"}   : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK:    %[[LHS_MAXVAL:.*]] = fir.load %[[LHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
+!CHECK:    %[[RHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"maxval"}   : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK:    %[[RHS_MAXVAL:.*]] = fir.load %[[RHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
+!CHECK:    %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
+!CHECK:    %[[MAX_VAL:.*]] = arith.select %[[CMP]], %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
+!CHECK:    %[[RESULT_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"maxval"}   : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK:    hlfir.assign %[[MAX_VAL]] to %[[RESULT_DESIGNATE_MAXVAL]] : i32, !fir.ref<i32>
+!CHECK:    %[[RESULT:.*]] = fir.load %[[RESULT_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK:    return %[[RESULT]] : [[MAXTYPE]]
+!CHECK:  }
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
new file mode 100644
index 0000000000000..2ff2499391c70
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
@@ -0,0 +1,59 @@
+! This test checks lowering of an OpenMP declare reduction directive with
+! initialization via a subroutine.
+
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+subroutine initme(x,n)  ! red_add initializer used by func below: sets x to 0; n is unused
+  integer x,n
+  x=0
+end subroutine initme
+
+function func(x, n, init)
+  integer func
+  integer x(n)
+  integer res
+  interface
+     subroutine initme(x,n)
+       integer x,n
+     end subroutine initme
+  end interface
+!CHECK:  omp.declare_reduction @red_add : i32 init {
+!CHECK: ^bb0(%[[ARGI_0:.*]]: i32):
+!CHECK:    %[[OMP_PRIV:.*]] = fir.alloca i32
+!CHECK:    %[[OMP_ORIG:.*]] = fir.alloca i32
+!CHECK:    fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<i32>
+!CHECK:    %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<i32>
+!CHECK:    %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    fir.call @_QPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
+!CHECK:    %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<i32>
+!CHECK:    omp.yield(%[[OMP_PRIV_VAL]] : i32)
+!CHECK:  } combiner {
+!CHECK:  ^bb0(%[[ARGC_0:.*]]: i32, %[[ARGC_1:.*]]: i32):
+!CHECK:    %[[OMP_OUT:.*]] = fir.alloca i32
+!CHECK:    %[[OMP_IN:.*]] = fir.alloca i32
+!CHECK:    fir.store %[[ARGC_1]] to %[[OMP_IN]] : !fir.ref<i32>
+!CHECK:    %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    fir.store %[[ARGC_0]] to %[[OMP_OUT]] : !fir.ref<i32>
+!CHECK:    %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
+!CHECK:    %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
+!CHECK:    %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
+!CHECK:    omp.yield(%[[SUM]] : i32)
+!CHECK:  }
+!CHECK:  func.func @_QPinitme(%[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
+!CHECK:    %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
+!CHECK:    %[[N_DECL:.*]]:2 = hlfir.declare %[[N]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFinitmeEn"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QFinitmeEx"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[CONST_0:.*]] = arith.constant 0 : i32
+!CHECK:    hlfir.assign %[[CONST_0]] to %[[X_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK:    return
+!CHECK:  }
+!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,omp_orig))
+  res=init
+!$omp simd reduction(red_add:res)
+  do i=1,n
+     res=res+x(i)
+  enddo
+  func=res
+end function func
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction.f90
new file mode 100644
index 0000000000000..107a49cbd46fc
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction.f90
@@ -0,0 +1,33 @@
+! This test checks lowering of an OpenMP declare reduction directive.
+
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+subroutine declare_red()
+  integer :: my_var
+!CHECK: omp.declare_reduction @my_red : i32 init {
+!CHECK: ^bb0(%[[ARGI_0:.*]]: i32):
+!CHECK:    %[[OMP_PRIV:.*]] = fir.alloca i32
+!CHECK:    %[[OMP_ORIG:.*]] = fir.alloca i32
+!CHECK:    fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<i32>
+!CHECK:    %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<i32>
+!CHECK:    %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[CONST_0:.*]] = arith.constant 0 : i32
+!CHECK:    omp.yield(%[[CONST_0]] : i32)
+!CHECK: } combiner {
+!CHECK:  ^bb0(%[[ARGC_0:.*]]: i32, %[[ARGC_1:.*]]: i32):
+!CHECK:    %[[OMP_OUT:.*]] = fir.alloca i32
+!CHECK:    %[[OMP_IN:.*]] = fir.alloca i32
+!CHECK:    fir.store %[[ARGC_1]] to %[[OMP_IN]] : !fir.ref<i32>
+!CHECK:    %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    fir.store %[[ARGC_0]] to %[[OMP_OUT]] : !fir.ref<i32>
+!CHECK:    %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
+!CHECK:    %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
+!CHECK:    %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
+!CHECK:    omp.yield(%[[SUM]] : i32)
+!CHECK: }
+
+  !$omp declare reduction (my_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv = 0)
+  my_var = 0
+end subroutine declare_red
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8edec990eaaba..0dcf0cf17f4d7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1170,6 +1170,7 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
 template <typename T>
 static void
 mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
+                      llvm::IRBuilderBase &builder,
                       SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
                       DenseMap<Value, llvm::Value *> &reductionVariableMap,
                       unsigned i) {
@@ -1180,8 +1181,17 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
 
   mlir::Value mlirSource = loop.getReductionVars()[i];
   llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
-  assert(llvmSource && "lookup reduction var");
-  moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
+  llvm::Value *origVal = llvmSource;
+  // If a non-pointer value is expected, load the value from the source pointer.
+  if (!isa<LLVM::LLVMPointerType>(
+          reduction.getInitializerMoldArg().getType()) &&
+      isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
+    origVal =
+        builder.CreateLoad(moduleTranslation.convertType(
+                               reduction.getInitializerMoldArg().getType()),
+                           llvmSource, "omp_orig");
+  }
+  moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
 
   if (entry.getNumArguments() > 1) {
     llvm::Value *allocation =
@@ -1254,7 +1264,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
     SmallVector<llvm::Value *, 1> phis;
 
     // map block argument to initializer region
-    mapInitializationArgs(op, moduleTranslation, reductionDecls,
+    mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
                           reductionVariableMap, i);
 
     // TODO In some cases (specially on the GPU), the init regions may

>From 10903790dc0f98ee55e229e860ce332bfcdaa29a Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Mon, 17 Nov 2025 13:09:10 -0500
Subject: [PATCH 2/2] Remove extra newline.

---
 flang/lib/Lower/OpenMP/ClauseProcessor.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 85ca9ecf9d98a..529b871330052 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -12,7 +12,6 @@
 #ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
-
 #include "ClauseFinder.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"



More information about the Mlir-commits mailing list