[flang] [llvm] [mlir] [OpenMP][flang] Lowering of OpenMP custom reductions to MLIR (PR #168417)
Jan Leyonberg via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 18 11:55:49 PST 2025
https://github.com/jsjodin updated https://github.com/llvm/llvm-project/pull/168417
>From f8a24cbfad7db15a6197e48e8f1d2343a1d02869 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Thu, 18 Sep 2025 10:50:04 -0400
Subject: [PATCH 1/6] [OpenMP][flang] Lowering of OpenMP custom reductions to
MLIR
This patch adds support for lowering OpenMP custom reductions (declare reduction) to MLIR.
---
.../flang/Lower/Support/ReductionProcessor.h | 18 +++
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 60 ++++++++
flang/lib/Lower/OpenMP/ClauseProcessor.h | 5 +
flang/lib/Lower/OpenMP/Clauses.cpp | 17 ++-
flang/lib/Lower/OpenMP/OpenMP.cpp | 133 ++++++++++++++++-
.../lib/Lower/Support/ReductionProcessor.cpp | 90 ++++++++----
.../Optimizer/OpenMP/MarkDeclareTarget.cpp | 139 +++++++++++++-----
.../Todo/omp-declare-reduction-initsub.f90 | 28 ----
.../OpenMP/Todo/omp-declare-reduction.f90 | 10 --
...are-target-deferred-marking-reductions.f90 | 37 +++++
.../omp-declare-reduction-derivedtype.f90 | 112 ++++++++++++++
.../OpenMP/omp-declare-reduction-initsub.f90 | 59 ++++++++
.../Lower/OpenMP/omp-declare-reduction.f90 | 33 +++++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 +-
14 files changed, 650 insertions(+), 107 deletions(-)
delete mode 100644 flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90
delete mode 100644 flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90
create mode 100644 flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
create mode 100644 flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
create mode 100644 flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
create mode 100644 flang/test/Lower/OpenMP/omp-declare-reduction.f90
diff --git a/flang/include/flang/Lower/Support/ReductionProcessor.h b/flang/include/flang/Lower/Support/ReductionProcessor.h
index 66f26b3b55630..bd0447360f089 100644
--- a/flang/include/flang/Lower/Support/ReductionProcessor.h
+++ b/flang/include/flang/Lower/Support/ReductionProcessor.h
@@ -40,6 +40,13 @@ namespace omp {
class ReductionProcessor {
public:
+ /// Callback that generates the initial value of a reduction's private copy
+ /// inside the init region of a declare reduction op. \p type is the
+ /// reduction type with any outer reference stripped and \p ompOrig is the
+ /// first argument of the init region. Returns the initial value.
+ using GenInitValueCBTy =
+ std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Type type, mlir::Value ompOrig)>;
+ /// Callback that generates the body of the combiner (reduction) region.
+ /// \p op1 and \p op2 are the first and second region arguments and
+ /// \p isByRef tells whether they are passed by reference. The callback must
+ /// also emit the region terminator; createDeclareReductionHelper does not
+ /// add one after calling it.
+ using GenCombinerCBTy = std::function<void(
+ fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
+ mlir::Value op1, mlir::Value op2, bool isByRef)>;
+
// TODO: Move this enumeration to the OpenMP dialect
enum ReductionIdentifier {
ID,
@@ -58,6 +65,9 @@ class ReductionProcessor {
IEOR
};
+ /// Returns true if a reduction over \p reductionType is performed by
+ /// reference: always when by-ref reductions are forced, otherwise when the
+ /// (reference-unwrapped) type is neither trivial nor a derived type.
+ static bool doReductionByRef(mlir::Type reductionType);
+ /// Same decision for a reduction variable; when \p reductionVar is produced
+ /// by an hlfir.declare, the declared memref's type is inspected instead.
+ static bool doReductionByRef(mlir::Value reductionVar);
+
static ReductionIdentifier
getReductionType(const omp::clause::ProcedureDesignator &pd);
@@ -109,6 +119,14 @@ class ReductionProcessor {
ReductionIdentifier redId,
mlir::Type type, mlir::Value op1,
mlir::Value op2);
+ /// Creates a reduction declaration op of type \p DeclareRedType named
+ /// \p reductionOpName at module scope, or returns the existing one if the
+ /// module already contains such a symbol. The init region is generated by
+ /// \p genInitValueCB and the combiner region by \p genCombinerCB, which is
+ /// responsible for terminating that region. \p isByRef is passed on to the
+ /// region generation and to \p genCombinerCB.
+ template <typename DeclareRedType>
+ static DeclareRedType createDeclareReductionHelper(
+ AbstractConverter &converter, llvm::StringRef reductionOpName,
+ mlir::Type type, mlir::Location loc, bool isByRef,
+ GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The declaration has a constant initializer with the neutral
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index e018a2d937435..fadfb29b07a28 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -13,6 +13,7 @@
#include "ClauseProcessor.h"
#include "Utils.h"
+#include "flang/Lower/ConvertCall.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/PFTBuilder.h"
@@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
return false;
}
+/// Builds the init-value callback for a DECLARE REDUCTION construct.
+///
+/// If the clause list contains an INITIALIZER clause, sets \p genInitValueCB
+/// to a callback that binds the stylized variables of \p inp (omp_priv,
+/// omp_orig) to temporaries holding the original value and then lowers the
+/// initializer expression or subroutine call against them. Returns true in
+/// that case and false when there is no INITIALIZER clause.
+///
+/// The callback captures \p symMap, \p inp and the processor's converter by
+/// reference, so it must be invoked while they are still alive.
+bool ClauseProcessor::processInitializer(
+    lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
+    ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
+  const auto *clause = findUniqueClause<omp::clause::Initializer>();
+  if (!clause)
+    return false;
+
+  genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
+                               mlir::Type type, mlir::Value ompOrig) {
+    lower::SymMapScope scope(symMap);
+    const parser::OmpInitializerExpression &iexpr = inp.v.v;
+    const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
+    const auto &declList =
+        std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
+
+    // Bind every stylized variable to a temporary initialized with the
+    // original value. omp_priv is remembered so that a subroutine-style
+    // initializer can read back the value it stored there.
+    mlir::Value ompPrivVar;
+    for (const parser::OmpStylizedDeclaration &decl : declList) {
+      auto &name = std::get<parser::ObjectName>(decl.var.t);
+      assert(name.symbol && "Name does not have a symbol");
+      mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
+      fir::StoreOp::create(builder, loc, ompOrig, addr);
+      fir::FortranVariableFlagsEnum extraFlags = {};
+      fir::FortranVariableFlagsAttr attributes =
+          Fortran::lower::translateSymbolAttributes(builder.getContext(),
+                                                    *name.symbol, extraFlags);
+      auto declareOp = hlfir::DeclareOp::create(
+          builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
+          0, attributes);
+      if (name.ToString() == "omp_priv")
+        ompPrivVar = declareOp.getResult(0);
+      symMap.addVariableDefinition(*name.symbol, declareOp);
+    }
+
+    // Lower the expression/function call.
+    lower::StatementContext stmtCtx;
+    mlir::Value result = common::visit(
+        common::visitors{
+            [&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
+              // A subroutine initializer defines omp_priv through its
+              // arguments; the initial value is whatever it stored there.
+              assert(ompPrivVar &&
+                     "subroutine initializer without an omp_priv variable");
+              convertCallToHLFIR(loc, converter, procRef, std::nullopt,
+                                 symMap, stmtCtx);
+              return fir::LoadOp::create(builder, loc, ompPrivVar);
+            },
+            [&](const auto &expr) -> mlir::Value {
+              mlir::Value exprResult = fir::getBase(convertExprToValue(
+                  loc, converter, clause->v, symMap, stmtCtx));
+              // Conversion can either give a value or a reference to a value.
+              // The callback must return a value of the reduction type, so
+              // load when the result is a reference to that type. Comparing
+              // with the reduction value type (rather than omp_priv's type)
+              // also works for by-ref reductions and does not depend on
+              // omp_priv having been declared.
+              if (auto refType =
+                      llvm::dyn_cast<fir::ReferenceType>(exprResult.getType()))
+                if (refType.getElementType() == fir::unwrapRefType(type))
+                  exprResult = fir::LoadOp::create(builder, loc, exprResult);
+              return exprResult;
+            }},
+        clause->v.u);
+    stmtCtx.finalizeAndPop();
+    return result;
+  };
+  return true;
+}
+
bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index d524b4ddc8ac4..85ca9ecf9d98a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -12,12 +12,14 @@
#ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
+
#include "ClauseFinder.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/OpenMP/Clauses.h"
+#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
@@ -88,6 +90,9 @@ class ClauseProcessor {
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
+ /// Sets \p genInitValueCB to a callback that lowers the INITIALIZER clause
+ /// \p inp of a declare reduction construct, binding its stylized variables
+ /// in \p symMap. Returns false if there is no INITIALIZER clause.
+ bool processInitializer(
+ lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
+ ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b1a3c3d3c5439..cf8d9a7ee6596 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,
Initializer make(const parser::OmpClause::Initializer &inp,
semantics::SemanticsContext &semaCtx) {
- llvm_unreachable("Empty: initializer");
+  // Only the first stylized instance is lowered. It is either an assignment
+  // (whose right-hand side is the initial value) or a subroutine call.
+  const parser::OmpInitializerExpression &iexpr = inp.v.v;
+  const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
+  const auto &instance =
+      std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
+
+  if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
+    const auto &expr = std::get<parser::Expr>(as->t);
+    return Initializer{makeExpr(expr, semaCtx)};
+  }
+
+  if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
+    // Every path must return a value: an unanalyzed call used to fall off
+    // the end of this non-void function.
+    if (!call->typedCall)
+      llvm_unreachable("Initializer call statement was not analyzed");
+    const auto &procRef = *call->typedCall;
+    semantics::SomeExpr evalProcRef{procRef};
+    return Initializer{evalProcRef};
+  }
+
+  llvm_unreachable("Unexpected initializer");
}
InReduction make(const parser::OmpClause::InReduction &inp,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f822fe3c8dd71..c4b174db8ac22 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -18,12 +18,15 @@
#include "Decomposer.h"
#include "Utils.h"
#include "flang/Common/idioms.h"
+#include "flang/Evaluate/type.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
+#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/StatementContext.h"
+#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// TODO: Add private syms and vars.
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;
-
return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
@@ -3563,12 +3565,137 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}
+/// Sets \p genCombinerCB to a callback that emits the combiner region of a
+/// DECLARE REDUCTION from the combiner of \p specifier.
+///
+/// omp_in is bound to the second region argument and any other stylized
+/// variable (omp_out) to the first; by-value arguments are first spilled to
+/// temporaries so they can be declared as variables. The callback evaluates
+/// the assignment's right-hand side and yields it (by value), or stores it
+/// into the first argument and yields that (by reference).
+///
+/// Only assignment-statement combiners are supported. Returns true when the
+/// callback was set.
+static bool
+processReductionCombiner(lower::AbstractConverter &converter,
+                         lower::SymMap &symTable,
+                         semantics::SemanticsContext &semaCtx,
+                         const parser::OmpReductionSpecifier &specifier,
+                         ReductionProcessor::GenCombinerCBTy &genCombinerCB) {
+  const auto &combinerExpression =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value();
+  const parser::OmpStylizedInstance &combinerInstance =
+      combinerExpression.v.front();
+  const parser::OmpStylizedInstance::Instance &instance =
+      std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
+  const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
+  if (!as) {
+    // Leaving genCombinerCB empty would make the caller invoke an empty
+    // std::function, so report the unsupported form here.
+    TODO(converter.getCurrentLocation(),
+         "declare reduction with a non-assignment combiner");
+    return false;
+  }
+  const parser::Expr &expr = std::get<parser::Expr>(as->t);
+
+  // expr and combinerInstance are references into the parse tree; capturing
+  // them by reference refers to the parse-tree nodes, which outlive the
+  // callback.
+  genCombinerCB = [&converter, &symTable, &semaCtx, &expr, &combinerInstance](
+                      fir::FirOpBuilder &builder, mlir::Location loc,
+                      mlir::Type /*type*/, mlir::Value lhs, mlir::Value rhs,
+                      bool isByRef) {
+    const auto &evalExpr = makeExpr(expr, semaCtx);
+    lower::SymMapScope scope(symTable);
+    const auto &declList =
+        std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+    for (const parser::OmpStylizedDeclaration &decl : declList) {
+      auto &name = std::get<parser::ObjectName>(decl.var.t);
+      assert(name.symbol && "Reduction object name does not have a symbol");
+      bool isRhs = name.ToString() == "omp_in";
+      mlir::Value arg = isRhs ? rhs : lhs;
+      mlir::Value addr = arg;
+      // By-value arguments must live in memory to be declared as variables.
+      if (!fir::conformsWithPassByRef(arg.getType())) {
+        addr = builder.createTemporary(loc, arg.getType());
+        fir::StoreOp::create(builder, loc, arg, addr);
+      }
+      fir::FortranVariableFlagsEnum extraFlags = {};
+      fir::FortranVariableFlagsAttr attributes =
+          Fortran::lower::translateSymbolAttributes(builder.getContext(),
+                                                    *name.symbol, extraFlags);
+      auto declareOp = hlfir::DeclareOp::create(
+          builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
+          0, attributes);
+      symTable.addVariableDefinition(*name.symbol, declareOp);
+    }
+
+    lower::StatementContext stmtCtx;
+    mlir::Value result = fir::getBase(
+        convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
+    // Load a reference result so a value is yielded/stored. lhs is unwrapped
+    // so that by-ref combiners (where lhs is itself a reference) also load
+    // instead of storing a reference into a reference.
+    if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
+      if (fir::unwrapRefType(lhs.getType()) == refType.getElementType())
+        result = fir::LoadOp::create(builder, loc, result);
+    stmtCtx.finalizeAndPop();
+    if (isByRef) {
+      fir::StoreOp::create(builder, loc, result, lhs);
+      mlir::omp::YieldOp::create(builder, loc, lhs);
+    } else {
+      mlir::omp::YieldOp::create(builder, loc, result);
+    }
+  };
+  return true;
+}
+
+/// Returns the MLIR type of a declare reduction, taken from the symbol of the
+/// first stylized variable of its combiner. Using the symbol rather than the
+/// DeclSpec avoids distinguishing derived and intrinsic types; semantics is
+/// expected to have attached symbols to these names.
+static mlir::Type
+getReductionType(lower::AbstractConverter &converter,
+                 const parser::OmpReductionSpecifier &specifier) {
+  const auto &combinerExpression =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value();
+  const parser::OmpStylizedInstance &combinerInstance =
+      combinerExpression.v.front();
+  const auto &declList =
+      std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+  assert(!declList.empty() && "combiner has no stylized declarations");
+  const auto &name = std::get<parser::ObjectName>(declList.front().var.t);
+  assert(name.symbol && "Reduction object name does not have a symbol");
+  const auto &symbol = semantics::SymbolRef(*name.symbol);
+  return converter.genType(symbol);
+}
+
+/// Lowers a DECLARE REDUCTION construct to an omp.declare_reduction operation
+/// named after the reduction identifier. Nothing is generated when the
+/// OpenMPSimd language option is set.
static void genOMP(
lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
- if (!semaCtx.langOptions().OpenMPSimd)
- TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
+  if (semaCtx.langOptions().OpenMPSimd)
+    return;
+
+  mlir::Location loc = converter.getCurrentLocation();
+  const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
+  const parser::OmpArgument &arg{args.v.front()};
+  const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
+
+  if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
+    TODO(loc, "multiple types in declare reduction is not yet supported");
+
+  // Only named reductions are handled. std::get would abort on operator
+  // identifiers (e.g. '+' for a derived type) or component references, so
+  // report those as unsupported instead.
+  const auto &identifier =
+      std::get<parser::OmpReductionIdentifier>(specifier.t);
+  const auto *designator =
+      std::get_if<parser::ProcedureDesignator>(&identifier.u);
+  const parser::Name *reductionName =
+      designator ? std::get_if<parser::Name>(&designator->u) : nullptr;
+  if (!reductionName)
+    TODO(loc, "declare reduction with an operator or component reduction "
+              "identifier is not yet supported");
+
+  const parser::OmpClauseList &initializer =
+      declareReductionConstruct.v.Clauses();
+  if (initializer.v.empty())
+    TODO(loc, "declare reduction without an initializer clause is not yet "
+              "supported");
+
+  mlir::Type reductionType = getReductionType(converter, specifier);
+  ReductionProcessor::GenCombinerCBTy genCombinerCB;
+  if (!processReductionCombiner(converter, symTable, semaCtx, specifier,
+                                genCombinerCB))
+    TODO(loc, "unsupported declare reduction combiner");
+
+  List<Clause> clauses = makeClauses(initializer, semaCtx);
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  const parser::OmpClause::Initializer &iclause{
+      std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
+  ReductionProcessor::GenInitValueCBTy genInitValueCB;
+  cp.processInitializer(symTable, iclause, genInitValueCB);
+
+  // The callbacks reference cp, symTable and the parse tree; they are only
+  // invoked from within createDeclareReductionHelper below.
+  bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+  ReductionProcessor::createDeclareReductionHelper<
+      mlir::omp::DeclareReductionOp>(converter, reductionName->ToString(),
+                                     reductionType, loc, isByRef,
+                                     genCombinerCB, genInitValueCB);
}
static void
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index 605a5b6b20b94..283e5ea73c319 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -462,7 +462,7 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
bool isByRef) {
ty = fir::unwrapRefType(ty);
- if (fir::isa_trivial(ty)) {
+ if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
@@ -501,7 +501,7 @@ static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
template <typename OpType>
static void createReductionAllocAndInitRegions(
AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
- const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
+ ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type,
bool isByRef) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };
@@ -523,9 +523,8 @@ static void createReductionAllocAndInitRegions(
mlir::Type ty = fir::unwrapRefType(type);
builder.setInsertionPointToEnd(initBlock);
- mlir::Value initValue = ReductionProcessor::getReductionInitValue(
- loc, unwrapSeqOrBoxedType(ty), redId, builder);
-
+ mlir::Value initValue =
+ genInitValueCB(builder, loc, ty, initBlock->getArgument(0));
if (isByRef) {
populateByRefInitAndCleanupRegions(
converter, loc, type, initValue, initBlock,
@@ -536,7 +535,7 @@ static void createReductionAllocAndInitRegions(
/*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
}
- if (fir::isa_trivial(ty)) {
+ if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
if (isByRef) {
// alloc region
builder.setInsertionPointToEnd(allocBlock);
@@ -556,18 +555,18 @@ static void createReductionAllocAndInitRegions(
yield(boxAlloca);
}
-template <typename OpType>
-OpType ReductionProcessor::createDeclareReduction(
+template <typename DeclareRedType>
+DeclareRedType ReductionProcessor::createDeclareReductionHelper(
AbstractConverter &converter, llvm::StringRef reductionOpName,
- const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
- bool isByRef) {
+ mlir::Type type, mlir::Location loc, bool isByRef,
+ GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::OpBuilder::InsertionGuard guard(builder);
mlir::ModuleOp module = builder.getModule();
assert(!reductionOpName.empty());
- auto decl = module.lookupSymbol<OpType>(reductionOpName);
+ auto decl = module.lookupSymbol<DeclareRedType>(reductionOpName);
if (decl)
return decl;
@@ -576,23 +575,54 @@ OpType ReductionProcessor::createDeclareReduction(
if (!isByRef)
type = valTy;
- decl = OpType::create(modBuilder, loc, reductionOpName, type);
- createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
+ decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type);
+ createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type,
isByRef);
-
builder.createBlock(&decl.getReductionRegion(),
decl.getReductionRegion().end(), {type, type},
{loc, loc});
-
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
- genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
-
+ genCombinerCB(builder, loc, type, op1, op2, isByRef);
return decl;
}
-static bool doReductionByRef(mlir::Value reductionVar) {
+/// Creates (or reuses) a declare reduction op for the intrinsic reduction
+/// identifier \p redId, using the identifier's neutral initial value and its
+/// built-in combiner.
+template <typename OpType>
+OpType ReductionProcessor::createDeclareReduction(
+    AbstractConverter &converter, llvm::StringRef reductionOpName,
+    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+    bool isByRef) {
+  // Both callbacks are invoked synchronously by createDeclareReductionHelper,
+  // so capturing redId by reference is safe. Parameters are named so they do
+  // not shadow this function's own loc/type/isByRef.
+  auto genInitValueCB = [&](fir::FirOpBuilder &builder, mlir::Location initLoc,
+                            mlir::Type initType, mlir::Value /*ompOrig*/) {
+    mlir::Type valueType = fir::unwrapRefType(initType);
+    return ReductionProcessor::getReductionInitValue(
+        initLoc, unwrapSeqOrBoxedType(valueType), redId, builder);
+  };
+  auto genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location combLoc,
+                           mlir::Type combType, mlir::Value op1,
+                           mlir::Value op2, bool combByRef) {
+    genCombiner<OpType>(builder, combLoc, redId, combType, op1, op2,
+                        combByRef);
+  };
+
+  return createDeclareReductionHelper<OpType>(converter, reductionOpName, type,
+                                              loc, isByRef, genCombinerCB,
+                                              genInitValueCB);
+}
+
+/// Decides whether a reduction of \p reductionType is done by reference: this
+/// is always the case when by-ref reductions are forced; otherwise only types
+/// that are neither trivial nor derived (after stripping a reference) are.
+bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
+  mlir::Type valueType = fir::unwrapRefType(reductionType);
+  bool passByValue = fir::isa_trivial(valueType) || fir::isa_derived(valueType);
+  return forceByrefReduction || !passByValue;
+}
+
+bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) {
if (forceByrefReduction)
return true;
@@ -600,10 +630,7 @@ static bool doReductionByRef(mlir::Value reductionVar) {
mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
reductionVar = declare.getMemref();
- if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
- return true;
-
- return false;
+ return doReductionByRef(reductionVar.getType());
}
template <typename OpType, typename RedOperatorListTy>
@@ -614,6 +641,8 @@ bool ReductionProcessor::processReductionArguments(
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
if constexpr (std::is_same_v<RedOperatorListTy,
omp::clause::ReductionOperatorList>) {
// For OpenMP reduction clauses, check if the reduction operator is
@@ -627,7 +656,13 @@ bool ReductionProcessor::processReductionArguments(
std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
- return false;
+ // If it is not an intrinsic, it has to be a custom reduction whose
+ // declare reduction op should already be available in the module.
+ semantics::Symbol *sym = reductionIntrinsic->v.sym();
+ mlir::ModuleOp module = builder.getModule();
+ auto decl = module.lookupSymbol<OpType>(getRealName(sym).ToString());
+ if (!decl)
+ return false;
}
} else {
return false;
@@ -637,7 +672,6 @@ bool ReductionProcessor::processReductionArguments(
// Reduction variable processing common to both intrinsic operators and
// procedure designators
- fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::OpBuilder::InsertPoint dcIP;
constexpr bool isDoConcurrent =
std::is_same_v<OpType, fir::DeclareReductionOp>;
@@ -741,7 +775,13 @@ bool ReductionProcessor::processReductionArguments(
&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
- TODO(currentLocation, "Unsupported intrinsic proc reduction");
+ // Custom reductions we can just add to the symbols without
+ // generating the declare reduction op.
+ semantics::Symbol *sym = reductionIntrinsic->v.sym();
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ builder.getContext(), sym->name().ToString()));
+ ++idx;
+ continue;
}
redId = getReductionType(*reductionIntrinsic);
reductionName =
diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
index 0b0e6bd9ecf34..1bd1cd6b05fff 100644
--- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
+++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
@@ -21,6 +21,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
namespace flangomp {
#define GEN_PASS_DEF_MARKDECLARETARGETPASS
@@ -31,9 +32,95 @@ namespace {
class MarkDeclareTargetPass
: public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
- void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
- mlir::omp::DeclareTargetCaptureClause parentCapClause,
- bool parentAutomap, mlir::Operation *currOp,
+ /// Declare target attributes propagated from the function or target region
+ /// being walked to the functions (and declare reduction ops) it references.
+ struct ParentInfo {
+ /// Device type to apply (nohost when starting from an omp.target op).
+ mlir::omp::DeclareTargetDeviceType devTy;
+ /// Capture clause to apply.
+ mlir::omp::DeclareTargetCaptureClause capClause;
+ /// Automap setting to apply.
+ bool automap;
+ };
+
+  /// If \p symRef names a func.func in the module, marks it declare target
+  /// with \p parentInfo (or widens its device type to `any` if it is already
+  /// marked with a different one) and then walks its body. \p visited is
+  /// taken by const reference; markNestedFuncs makes its own copy.
+  void processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo,
+                        const llvm::SmallPtrSet<mlir::Operation *, 16> &visited) {
+    auto currFOp = getOperation().lookupSymbol<mlir::func::FuncOp>(symRef);
+    if (!currFOp)
+      return;
+
+    auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
+        currFOp.getOperation());
+
+    if (current.isDeclareTarget()) {
+      auto currentDt = current.getDeclareTargetDeviceType();
+
+      // Found the same function twice, with different device_types,
+      // mark as Any as it belongs to both
+      if (currentDt != parentInfo.devTy &&
+          currentDt != mlir::omp::DeclareTargetDeviceType::any) {
+        current.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
+                                 current.getDeclareTargetCaptureClause(),
+                                 current.getDeclareTargetAutomap());
+      }
+    } else {
+      current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
+                               parentInfo.automap);
+    }
+
+    markNestedFuncs(parentInfo, currFOp, visited);
+  }
+
+  /// Walks every omp.declare_reduction op referenced by \p symRefs, so that
+  /// functions called from its regions (e.g. a user-provided combiner or
+  /// initializer) are marked too. \p visited is taken by const reference;
+  /// markNestedFuncs makes its own copy.
+  void
+  processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
+                       ParentInfo parentInfo,
+                       const llvm::SmallPtrSet<mlir::Operation *, 16> &visited) {
+    if (!symRefs)
+      return;
+
+    for (auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>())
+      if (auto declareReductionOp =
+              getOperation().lookupSymbol<mlir::omp::DeclareReductionOp>(
+                  symRef))
+        markNestedFuncs(parentInfo, declareReductionOp, visited);
+  }
+
+  /// Processes the declare reduction ops referenced by the reduction,
+  /// in_reduction or task_reduction symbol lists of \p op, depending on which
+  /// of these the operation carries.
+  void
+  processReductionClauses(mlir::Operation *op, ParentInfo parentInfo,
+                          const llvm::SmallPtrSet<mlir::Operation *, 16> &visited) {
+    auto process = [&](std::optional<mlir::ArrayAttr> syms) {
+      processReductionRefs(syms, parentInfo, visited);
+    };
+    llvm::TypeSwitch<mlir::Operation &>(*op)
+        // Operations with a plain reduction clause.
+        .Case<mlir::omp::LoopOp, mlir::omp::ParallelOp, mlir::omp::SectionsOp,
+              mlir::omp::SimdOp, mlir::omp::TeamsOp, mlir::omp::WsloopOp>(
+            [&](auto redOp) { process(redOp.getReductionSyms()); })
+        // Operations with only an in_reduction clause.
+        .Case<mlir::omp::TargetOp, mlir::omp::TaskOp>(
+            [&](auto redOp) { process(redOp.getInReductionSyms()); })
+        .Case([&](mlir::omp::TaskgroupOp redOp) {
+          process(redOp.getTaskReductionSyms());
+        })
+        .Case([&](mlir::omp::TaskloopOp redOp) {
+          process(redOp.getReductionSyms());
+          process(redOp.getInReductionSyms());
+        })
+        .Default([](mlir::Operation &) {});
+  }
+
+ void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
if (visited.contains(currOp))
return;
@@ -43,33 +130,10 @@ class MarkDeclareTargetPass
if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
callOp.getCallableForCallee())) {
- if (auto currFOp =
- getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
- auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
- currFOp.getOperation());
-
- if (current.isDeclareTarget()) {
- auto currentDt = current.getDeclareTargetDeviceType();
-
- // Found the same function twice, with different device_types,
- // mark as Any as it belongs to both
- if (currentDt != parentDevTy &&
- currentDt != mlir::omp::DeclareTargetDeviceType::any) {
- current.setDeclareTarget(
- mlir::omp::DeclareTargetDeviceType::any,
- current.getDeclareTargetCaptureClause(),
- current.getDeclareTargetAutomap());
- }
- } else {
- current.setDeclareTarget(parentDevTy, parentCapClause,
- parentAutomap);
- }
-
- markNestedFuncs(parentDevTy, parentCapClause, parentAutomap,
- currFOp, visited);
- }
+ processSymbolRef(symRef, parentInfo, visited);
}
}
+ processReductionClauses(op, parentInfo, visited);
});
}
@@ -82,10 +146,10 @@ class MarkDeclareTargetPass
functionOp.getOperation());
if (declareTargetOp.isDeclareTarget()) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
- markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
- declareTargetOp.getDeclareTargetCaptureClause(),
- declareTargetOp.getDeclareTargetAutomap(), functionOp,
- visited);
+ ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
+ declareTargetOp.getDeclareTargetCaptureClause(),
+ declareTargetOp.getDeclareTargetAutomap()};
+ markNestedFuncs(parentInfo, functionOp, visited);
}
}
@@ -96,12 +160,13 @@ class MarkDeclareTargetPass
// the contents of the device clause
getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
- markNestedFuncs(
- /*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
- /*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to,
- /*parentAutomap=*/false, tarOp, visited);
+ ParentInfo parentInfo = {
+ /*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
+ /*capClause=*/mlir::omp::DeclareTargetCaptureClause::to,
+ /*automap=*/false,
+ };
+ markNestedFuncs(parentInfo, tarOp, visited);
});
}
};
-
} // namespace
diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90
deleted file mode 100644
index 30630465490b2..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90
+++ /dev/null
@@ -1,28 +0,0 @@
-! This test checks lowering of OpenMP declare reduction Directive, with initialization
-! via a subroutine. This functionality is currently not implemented.
-
-! RUN: not flang -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
-
-!CHECK: not yet implemented: OpenMPDeclareReductionConstruct
-subroutine initme(x,n)
- integer x,n
- x=n
-end subroutine initme
-
-function func(x, n, init)
- integer func
- integer x(n)
- integer res
- interface
- subroutine initme(x,n)
- integer x,n
- end subroutine initme
- end interface
-!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,0))
- res=init
-!$omp simd reduction(red_add:res)
- do i=1,n
- res=res+x(i)
- enddo
- func=res
-end function func
diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90
deleted file mode 100644
index db50c9ac8ee9d..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90
+++ /dev/null
@@ -1,10 +0,0 @@
-! This test checks lowering of OpenMP declare reduction Directive.
-
-! RUN: not flang -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
-
-subroutine declare_red()
- integer :: my_var
- !CHECK: not yet implemented: OpenMPDeclareReductionConstruct
- !$omp declare reduction (my_red : integer : omp_out = omp_in) initializer (omp_priv = 0)
- my_var = 0
-end subroutine declare_red
diff --git a/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90 b/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
new file mode 100644
index 0000000000000..a3c38d7ba0a25
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
@@ -0,0 +1,37 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s --check-prefixes ALL
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL
+
+program main
+ use, intrinsic :: iso_c_binding
+ implicit none
+ interface
+ subroutine myinit(priv, orig) bind(c,name="myinit")
+ use, intrinsic :: iso_c_binding
+ implicit none
+ integer::priv, orig
+ end subroutine myinit
+
+ function mycombine(lhs, rhs) bind(c,name="mycombine")
+ use, intrinsic :: iso_c_binding
+ implicit none
+ integer::lhs, rhs, mycombine
+ end function mycombine
+ end interface
+ !$omp declare reduction(myreduction:integer:omp_out = mycombine(omp_out, omp_in)) initializer(myinit(omp_priv, omp_orig))
+
+ integer :: i, s, a(10)
+ !$omp target
+ s = 0
+ !$omp do reduction(myreduction:s)
+ do i = 1, 10
+ s = mycombine(s, a(i))
+ enddo
+ !$omp end do
+ !$omp end target
+ end program main
+
+!ALL-LABEL: func.func {{.*}} @myinit(!fir.ref<i32>, !fir.ref<i32>)
+ !ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
+ !ALL-LABEL: func.func {{.*}} @mycombine(!fir.ref<i32>, !fir.ref<i32>)
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
+
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
new file mode 100644
index 0000000000000..d544c0db488f0
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
@@ -0,0 +1,112 @@
+! This test checks lowering of OpenMP declare reduction Directive, with initialization
+! via a subroutine, using a derived type as the reduction type.
+
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+module maxtype_mod
+ implicit none
+
+ type maxtype
+ integer::sumval
+ integer::maxval
+ end type maxtype
+
+contains
+
+ subroutine initme(x,n)
+ type(maxtype) :: x,n
+ x%sumval=0
+ x%maxval=0
+ end subroutine initme
+
+ function mycombine(lhs, rhs)
+ type(maxtype) :: lhs, rhs
+ type(maxtype) :: mycombine
+ mycombine%sumval = lhs%sumval + rhs%sumval
+ mycombine%maxval = max(lhs%maxval, rhs%maxval)
+ end function mycombine
+
+ function func(x, n, init)
+ type(maxtype) :: func
+ integer :: n, i
+ type(maxtype) :: x(n)
+ type(maxtype) :: init
+ type(maxtype) :: res
+!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig))
+ res=init
+!$omp simd reduction(red_add_max:res)
+ do i=1,n
+ res=mycombine(res,x(i))
+ enddo
+ func=res
+ end function func
+
+end module maxtype_mod
+!CHECK: omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init {
+!CHECK: ^bb0(%[[ARGI_0:.*]]: [[MAXTYPE]]):
+!CHECK: %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
+!CHECK: %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
+!CHECK: fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
+!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]])
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARGC_0:.*]]: [[MAXTYPE]], %[[ARGC_1:.*]]: [[MAXTYPE]]):
+!CHECK: %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"}
+!CHECK: %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]]
+!CHECK: %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]]
+!CHECK: fir.store %[[ARGC_1]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: fir.store %[[ARGC_0]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
+!CHECK: fir.save_result %[[COMBINE_RESULT]] to %[[RESULT]] : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]>
+!CHECK: %[[TMPRESULT:.*]]:2 = hlfir.declare %[[RESULT]] {uniq_name = ".tmp.func_result"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %false = arith.constant false
+!CHECK: %[[EXPRRESULT:.*]] = hlfir.as_expr %[[TMPRESULT]]#0 move %false : (!fir.ref<[[MAXTYPE]]>, i1) -> !hlfir.expr<[[MAXTYPE]]>
+!CHECK: %[[ASSOCIATE:.*]]:3 = hlfir.associate %[[EXPRRESULT]] {adapt.valuebyref} : (!hlfir.expr<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>, i1)
+!CHECK: %[[RESULT_VAL:.*]] = fir.load %[[ASSOCIATE]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK: hlfir.end_associate %[[ASSOCIATE]]#1, %[[ASSOCIATE]]#2 : !fir.ref<[[MAXTYPE]]>, i1
+!CHECK: omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]])
+!CHECK: }
+
+!CHECK: func.func @_QMmaxtype_modPinitme(%arg0: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %arg1: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {
+!CHECK: %0 = fir.dummy_scope : !fir.dscope
+!CHECK: %1:2 = hlfir.declare %arg1 dummy_scope %0 arg 2 {uniq_name = "_QMmaxtype_modFinitmeEn"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %arg0 dummy_scope %0 arg 1 {uniq_name = "_QMmaxtype_modFinitmeEx"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[ZERO_0:.*]] = arith.constant 0 : i32
+!CHECK: %[[X_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK: hlfir.assign %[[ZERO_0]] to %[[X_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
+!CHECK: %[[ZERO_1:.*]] = arith.constant 0 : i32
+!CHECK: %[[X_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK: hlfir.assign %[[ZERO_1]] to %[[X_DESIGNATE_MAXVAL]] : i32, !fir.ref<i32>
+!CHECK: return
+!CHECK: }
+
+
+!CHECK: func.func @_QMmaxtype_modPmycombine(%[[LHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "lhs"}, %[[RHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "rhs"}) -> [[MAXTYPE]] {
+!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
+!CHECK: %[[LHS_DECL:.*]]:2 = hlfir.declare %[[LHS]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFmycombineElhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[RESULT_ALLOC:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = "mycombine", uniq_name = "_QMmaxtype_modFmycombineEmycombine"}
+!CHECK: %[[RESULT_DECL:.*]]:2 = hlfir.declare %[[RESULT_ALLOC]] {uniq_name = "_QMmaxtype_modFmycombineEmycombine"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[RHS_DECL:.*]]:2 = hlfir.declare %[[RHS]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFmycombineErhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[LHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK: %[[LHS_SUMVAL:.*]] = fir.load %[[LHS_DESIGNATE_SUMVAL]] : !fir.ref<i32>
+!CHECK: %[[RHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK: %[[RHS_SUMVAL:.*]] = fir.load %[[RHS_DESIGNATE_SUMVAL]] : !fir.ref<i32>
+!CHECK: %[[SUM:.*]] = arith.addi %[[LHS_SUMVAL]], %[[RHS_SUMVAL]] : i32
+!CHECK: %[[RESULT_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK: hlfir.assign %[[SUM]] to %[[RESULT_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
+!CHECK: %[[LHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK: %[[LHS_MAXVAL:.*]] = fir.load %[[LHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
+!CHECK: %[[RHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK: %[[RHS_MAXVAL:.*]] = fir.load %13 : !fir.ref<i32>
+!CHECK: %[[CMP:.*]] = arith.cmpi sgt, %12, %14 : i32
+!CHECK: %[[MAX_VAL:.*]] = arith.select %[[CMP]], %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
+!CHECK: %[[RESULT_DESIGNAGE_MAXVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
+!CHECK: hlfir.assign %[[MAX_VAL]] to %[[RESULT_DESIGNAGE_MAXVAL]] : i32, !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = fir.load %[[RESULT_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
+!CHECK: return %[[RESULT]] : [[MAXTYPE]]
+!CHECK: }
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
new file mode 100644
index 0000000000000..2ff2499391c70
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
@@ -0,0 +1,59 @@
+! This test checks lowering of OpenMP declare reduction Directive, with initialization
+! via a subroutine that is called with omp_priv and omp_orig.
+
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+subroutine initme(x,n)
+ integer x,n
+ x=0
+end subroutine initme
+
+function func(x, n, init)
+ integer func
+ integer x(n)
+ integer res
+ interface
+ subroutine initme(x,n)
+ integer x,n
+ end subroutine initme
+ end interface
+!CHECK: omp.declare_reduction @red_add : i32 init {
+!CHECK: ^bb0(%[[ARGI_0:.*]]: i32):
+!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
+!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
+!CHECK: fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<i32>
+!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<i32>
+!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.call @_QPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
+!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<i32>
+!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARGC_0:.*]]: i32, %[[ARGC_1:.*]]: i32):
+!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
+!CHECK: %[[OMP_IN:.*]]1 = fir.alloca i32
+!CHECK: fir.store %[[ARGC_1]] to %1 : !fir.ref<i32>
+!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %1 {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.store %[[ARGC_0]] to %0 : !fir.ref<i32>
+!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %0 {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
+!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
+!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
+!CHECK: omp.yield(%[[SUM]] : i32)
+!CHECK: }
+!CHECK: func.func @_QPinitme(%[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
+!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
+!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFinitmeEn"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] dummy_scope %0 arg 1 {uniq_name = "_QFinitmeEx"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
+!CHECK: hlfir.assign %[[CONST_0]] to %[[X_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK: return
+!CHECK: }
+!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,omp_orig))
+ res=init
+!$omp simd reduction(red_add:res)
+ do i=1,n
+ res=res+x(i)
+ enddo
+ func=res
+end function func
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction.f90
new file mode 100644
index 0000000000000..107a49cbd46fc
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction.f90
@@ -0,0 +1,33 @@
+! This test checks lowering of OpenMP declare reduction Directive.
+
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+subroutine declare_red()
+ integer :: my_var
+!CHECK: omp.declare_reduction @my_red : i32 init {
+!CHECK: ^bb0(%[[ARGI_0:.*]]: i32):
+!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
+!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
+!CHECK: fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<i32>
+!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<i32>
+!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
+!CHECK: omp.yield(%[[CONST_0]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARGC_0:.*]]: i32, %[[ARGC_1:.*]]: i32):
+!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
+!CHECK: %[[OMP_IN:.*]]1 = fir.alloca i32
+!CHECK: fir.store %[[ARGC_1]] to %1 : !fir.ref<i32>
+!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %1 {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.store %[[ARGC_0]] to %0 : !fir.ref<i32>
+!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %0 {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
+!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
+!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
+!CHECK: omp.yield(%[[SUM]] : i32)
+!CHECK: }
+
+ !$omp declare reduction (my_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv = 0)
+ my_var = 0
+end subroutine declare_red
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8edec990eaaba..0dcf0cf17f4d7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1170,6 +1170,7 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
template <typename T>
static void
mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
+ llvm::IRBuilderBase &builder,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
DenseMap<Value, llvm::Value *> &reductionVariableMap,
unsigned i) {
@@ -1180,8 +1181,17 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
mlir::Value mlirSource = loop.getReductionVars()[i];
llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
- assert(llvmSource && "lookup reduction var");
- moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
+ llvm::Value *origVal = llvmSource;
+ // If a non-pointer value is expected, load the value from the source pointer.
+ if (!isa<LLVM::LLVMPointerType>(
+ reduction.getInitializerMoldArg().getType()) &&
+ isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
+ origVal =
+ builder.CreateLoad(moduleTranslation.convertType(
+ reduction.getInitializerMoldArg().getType()),
+ llvmSource, "omp_orig");
+ }
+ moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
if (entry.getNumArguments() > 1) {
llvm::Value *allocation =
@@ -1254,7 +1264,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
SmallVector<llvm::Value *, 1> phis;
// map block argument to initializer region
- mapInitializationArgs(op, moduleTranslation, reductionDecls,
+ mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
reductionVariableMap, i);
// TODO In some cases (specially on the GPU), the init regions may
>From 10903790dc0f98ee55e229e860ce332bfcdaa29a Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Mon, 17 Nov 2025 13:09:10 -0500
Subject: [PATCH 2/6] Remove extra newline.
---
flang/lib/Lower/OpenMP/ClauseProcessor.h | 1 -
1 file changed, 1 deletion(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 85ca9ecf9d98a..529b871330052 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -12,7 +12,6 @@
#ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
-
#include "ClauseFinder.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
>From c5aec84a9507b2495e1e4394be1639461e6d1e74 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Mon, 17 Nov 2025 13:21:03 -0500
Subject: [PATCH 3/6] Fix formatting
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 16 ++++++++--------
flang/lib/Lower/Support/ReductionProcessor.cpp | 2 +-
flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp | 12 +++++-------
3 files changed, 14 insertions(+), 16 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c4b174db8ac22..eacc5dc939225 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3577,7 +3577,7 @@ processReductionCombiner(lower::AbstractConverter &converter,
const parser::OmpStylizedInstance &combinerInstance =
combinerExpression.v.front();
const parser::OmpStylizedInstance::Instance &instance =
- std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
+ std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
auto &expr = std::get<parser::Expr>(as->t);
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
@@ -3586,7 +3586,8 @@ processReductionCombiner(lower::AbstractConverter &converter,
const auto &evalExpr = makeExpr(expr, semaCtx);
lower::SymMapScope scope(symTable);
const std::list<parser::OmpStylizedDeclaration> &declList =
- std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+ std::get<std::list<parser::OmpStylizedDeclaration>>(
+ combinerInstance.t);
for (const parser::OmpStylizedDeclaration &decl : declList) {
auto &name = std::get<parser::ObjectName>(decl.var.t);
mlir::Value addr = lhs;
@@ -3599,13 +3600,13 @@ processReductionCombiner(lower::AbstractConverter &converter,
assert(name.symbol && "Reduction object name does not have a symbol");
if (!fir::conformsWithPassByRef(type)) {
- addr = builder.createTemporary(loc, type);
- fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
+ addr = builder.createTemporary(loc, type);
+ fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
}
fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
- Fortran::lower::translateSymbolAttributes(builder.getContext(),
- *name.symbol, extraFlags);
+ Fortran::lower::translateSymbolAttributes(builder.getContext(),
+ *name.symbol, extraFlags);
auto declareOp = hlfir::DeclareOp::create(
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
0, attributes);
@@ -3615,8 +3616,7 @@ processReductionCombiner(lower::AbstractConverter &converter,
lower::StatementContext stmtCtx;
mlir::Value result = fir::getBase(
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
- if (auto refType =
- llvm::dyn_cast<fir::ReferenceType>(result.getType()))
+ if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
if (lhs.getType() == refType.getElementType())
result = fir::LoadOp::create(builder, loc, result);
stmtCtx.finalizeAndPop();
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index 283e5ea73c319..12c8b8f33c414 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -619,7 +619,7 @@ bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
!fir::isa_derived(fir::unwrapRefType(reductionType)))
return true;
- return false;
+ return false;
}
bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) {
diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
index 1bd1cd6b05fff..5fa77fb2080df 100644
--- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
+++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
@@ -65,10 +65,9 @@ class MarkDeclareTargetPass
}
}
- void
- processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
- ParentInfo parentInfo,
- llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
+ void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
+ ParentInfo parentInfo,
+ llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
if (!symRefs)
return;
@@ -82,11 +81,10 @@ class MarkDeclareTargetPass
}
void
- processReductionClauses(mlir::Operation *op,
- ParentInfo parentInfo,
+ processReductionClauses(mlir::Operation *op, ParentInfo parentInfo,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
llvm::TypeSwitch<mlir::Operation &>(*op)
- .Case([&](mlir::omp::LoopOp op) {
+ .Case([&](mlir::omp::LoopOp op) {
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::ParallelOp op) {
>From d914b85b68afd67e81da7a85f4caf43b468d0720 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Mon, 17 Nov 2025 13:40:21 -0500
Subject: [PATCH 4/6] Fix
---
flang/lib/Lower/OpenMP/Clauses.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index cf8d9a7ee6596..dc49a8118b0a5 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -994,9 +994,9 @@ Initializer make(const parser::OmpClause::Initializer &inp,
semantics::SomeExpr evalProcRef{procRef};
return Initializer{evalProcRef};
}
- } else {
- llvm_unreachable("Unexpected initializer");
}
+
+ llvm_unreachable("Unexpected initializer");
}
InReduction make(const parser::OmpClause::InReduction &inp,
>From 271ad671e73add5f889abe263c597c8b29f4524a Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Tue, 18 Nov 2025 14:16:10 -0500
Subject: [PATCH 5/6] [Flang][OpenMP] Add offload runtime test for custom
reduction with derived types
---
.../target-custom-reduction-derivedtype.f90 | 88 +++++++++++++++++++
1 file changed, 88 insertions(+)
create mode 100644 offload/test/offloading/fortran/target-custom-reduction-derivedtype.f90
diff --git a/offload/test/offloading/fortran/target-custom-reduction-derivedtype.f90 b/offload/test/offloading/fortran/target-custom-reduction-derivedtype.f90
new file mode 100644
index 0000000000000..cc390cf0881f3
--- /dev/null
+++ b/offload/test/offloading/fortran/target-custom-reduction-derivedtype.f90
@@ -0,0 +1,88 @@
+! Basic offloading test with a custom OpenMP reduction on a derived type
+! REQUIRES: flang, amdgpu
+!
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+module maxtype_mod
+ implicit none
+
+ type maxtype
+ integer::sumval
+ integer::maxval
+ end type maxtype
+
+contains
+
+ subroutine initme(x,n)
+ type(maxtype) :: x,n
+ x%sumval=0
+ x%maxval=0
+ end subroutine initme
+
+ function mycombine(lhs, rhs)
+ type(maxtype) :: lhs, rhs
+ type(maxtype) :: mycombine
+ mycombine%sumval = lhs%sumval + rhs%sumval
+ mycombine%maxval = max(lhs%maxval, rhs%maxval)
+ end function mycombine
+
+end module maxtype_mod
+
+program main
+ use maxtype_mod
+ implicit none
+
+ integer :: n = 100
+ integer :: i
+ integer :: error = 0
+ type(maxtype) :: x(100)
+ type(maxtype) :: res
+ integer :: expected_sum, expected_max
+
+!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig))
+
+ ! Initialize array with test data
+ do i = 1, n
+ x(i)%sumval = i
+ x(i)%maxval = i
+ end do
+
+ ! Initialize reduction variable
+ res%sumval = 0
+ res%maxval = 0
+
+ ! Perform reduction in target region
+ !$omp target parallel do map(to:x) reduction(red_add_max:res)
+ do i = 1, n
+ res = mycombine(res, x(i))
+ end do
+ !$omp end target parallel do
+
+ ! Compute expected values
+ expected_sum = 0
+ expected_max = 0
+ do i = 1, n
+ expected_sum = expected_sum + i
+ expected_max = max(expected_max, i)
+ end do
+
+ ! Check results
+ if (res%sumval /= expected_sum) then
+ error = 1
+ endif
+
+ if (res%maxval /= expected_max) then
+ error = 1
+ endif
+
+ if (error == 0) then
+ print *,"PASSED"
+ else
+ print *,"FAILED"
+ endif
+
+end program main
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: PASSED
+
>From be3bb13740062d3bd23359ce58e7e6bfa8fdb899 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Tue, 18 Nov 2025 14:54:57 -0500
Subject: [PATCH 6/6] Fix tests and comments.
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 5 ++--
...are-target-deferred-marking-reductions.f90 | 12 +++++-----
.../omp-declare-reduction-derivedtype.f90 | 24 +++++++++----------
.../OpenMP/omp-declare-reduction-initsub.f90 | 20 ++++++++--------
.../Lower/OpenMP/omp-declare-reduction.f90 | 18 +++++++-------
5 files changed, 40 insertions(+), 39 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index eacc5dc939225..2df960a8ccc4a 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3665,7 +3665,7 @@ static void genOMP(
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
TODO(converter.getCurrentLocation(),
- "multiple types in declare target is not yet supported");
+ "multiple types in declare reduction is not yet supported");
mlir::Type reductionType = getReductionType(converter, specifier);
ReductionProcessor::GenCombinerCBTy genCombinerCB;
@@ -3693,7 +3693,8 @@ static void genOMP(
genInitValueCB);
} else {
TODO(converter.getCurrentLocation(),
- "declare target without an initializer clause is not yet supported");
+ "declare reduction without an initializer clause is not yet "
+ "supported");
}
}
}
diff --git a/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90 b/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
index a3c38d7ba0a25..66697ef6bbe70 100644
--- a/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
+++ b/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90
@@ -1,5 +1,5 @@
-!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s --check-prefixes ALL
-!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -fopenmp-is-device %s -o - | FileCheck %s
program main
use, intrinsic :: iso_c_binding
@@ -30,8 +30,8 @@ end function mycombine
!$omp end target
end program main
-!ALL-LABEL: func.func {{.*}} @myinit(!fir.ref<i32>, !fir.ref<i32>)
- !ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
- !ALL-LABEL: func.func {{.*}} @mycombine(!fir.ref<i32>, !fir.ref<i32>)
-!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
+!CHECK: func.func {{.*}} @myinit(!fir.ref<i32>, !fir.ref<i32>)
+!CHECK-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
+!CHECK-LABEL: func.func {{.*}} @mycombine(!fir.ref<i32>, !fir.ref<i32>)
+!CHECK-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
index d544c0db488f0..36bb131e677a3 100644
--- a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
@@ -42,24 +42,24 @@ end function func
end module maxtype_mod
!CHECK: omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init {
-!CHECK: ^bb0(%[[ARGI_0:.*]]: [[MAXTYPE]]):
+!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: [[MAXTYPE]]):
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]])
!CHECK: } combiner {
-!CHECK: ^bb0(%[[ARGC_0:.*]]: [[MAXTYPE]], %[[ARGC_1:.*]]: [[MAXTYPE]]):
+!CHECK: ^bb0(%[[LHS_ARG:.*]]: [[MAXTYPE]], %[[RHS_ARG:.*]]: [[MAXTYPE]]):
!CHECK: %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"}
!CHECK: %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]]
!CHECK: %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]]
-!CHECK: fir.store %[[ARGC_1]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: fir.store %[[ARGC_0]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
+!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
!CHECK: fir.save_result %[[COMBINE_RESULT]] to %[[RESULT]] : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]>
@@ -72,10 +72,10 @@ end module maxtype_mod
!CHECK: omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]])
!CHECK: }
-!CHECK: func.func @_QMmaxtype_modPinitme(%arg0: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %arg1: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {
-!CHECK: %0 = fir.dummy_scope : !fir.dscope
-!CHECK: %1:2 = hlfir.declare %arg1 dummy_scope %0 arg 2 {uniq_name = "_QMmaxtype_modFinitmeEn"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
-!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %arg0 dummy_scope %0 arg 1 {uniq_name = "_QMmaxtype_modFinitmeEx"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: func.func @_QMmaxtype_modPinitme(%[[X_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %[[N_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {
+!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
+!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N_ARG]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFinitmeEn"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
+!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ARG]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFinitmeEx"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[ZERO_0:.*]] = arith.constant 0 : i32
!CHECK: %[[X_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: hlfir.assign %[[ZERO_0]] to %[[X_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
@@ -102,8 +102,8 @@ end module maxtype_mod
!CHECK: %[[LHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: %[[LHS_MAXVAL:.*]] = fir.load %[[LHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
!CHECK: %[[RHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
-!CHECK: %[[RHS_MAXVAL:.*]] = fir.load %13 : !fir.ref<i32>
-!CHECK: %[[CMP:.*]] = arith.cmpi sgt, %12, %14 : i32
+!CHECK: %[[RHS_MAXVAL:.*]] = fir.load %[[RHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
+!CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
!CHECK: %[[MAX_VAL:.*]] = arith.select %[[CMP]], %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
!CHECK: %[[RESULT_DESIGNAGE_MAXVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: hlfir.assign %[[MAX_VAL]] to %[[RESULT_DESIGNAGE_MAXVAL]] : i32, !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
index 2ff2499391c70..4aacc7cb2efba 100644
--- a/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
@@ -18,24 +18,24 @@ subroutine initme(x,n)
end subroutine initme
end interface
!CHECK: omp.declare_reduction @red_add : i32 init {
-!CHECK: ^bb0(%[[ARGI_0:.*]]: i32):
+!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
-!CHECK: fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<i32>
+!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<i32>
+!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: fir.call @_QPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<i32>
!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : i32)
!CHECK: } combiner {
-!CHECK: ^bb0(%[[ARGC_0:.*]]: i32, %[[ARGC_1:.*]]: i32):
+!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
-!CHECK: %[[OMP_IN:.*]]1 = fir.alloca i32
-!CHECK: fir.store %[[ARGC_1]] to %1 : !fir.ref<i32>
-!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %1 {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: fir.store %[[ARGC_0]] to %0 : !fir.ref<i32>
-!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %0 {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[OMP_IN:.*]] = fir.alloca i32
+!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
+!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
+!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
@@ -44,7 +44,7 @@ end subroutine initme
!CHECK: func.func @_QPinitme(%[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFinitmeEn"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] dummy_scope %0 arg 1 {uniq_name = "_QFinitmeEx"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QFinitmeEx"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[CONST_0]] to %[[X_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: return
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction.f90
index 107a49cbd46fc..a41f6b214b9d8 100644
--- a/flang/test/Lower/OpenMP/omp-declare-reduction.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction.f90
@@ -5,23 +5,23 @@
subroutine declare_red()
integer :: my_var
!CHECK: omp.declare_reduction @my_red : i32 init {
-!CHECK: ^bb0(%[[ARGI_0:.*]]: i32):
+!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
-!CHECK: fir.store %[[ARGI_0]] to %[[OMP_ORIG]] : !fir.ref<i32>
+!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: fir.store %[[ARGI_0]] to %[[OMP_PRIV]] : !fir.ref<i32>
+!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
!CHECK: omp.yield(%[[CONST_0]] : i32)
!CHECK: } combiner {
-!CHECK: ^bb0(%[[ARGC_0:.*]]: i32, %[[ARGC_1:.*]]: i32):
+!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
-!CHECK: %[[OMP_IN:.*]]1 = fir.alloca i32
-!CHECK: fir.store %[[ARGC_1]] to %1 : !fir.ref<i32>
-!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %1 {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: fir.store %[[ARGC_0]] to %0 : !fir.ref<i32>
-!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %0 {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[OMP_IN:.*]] = fir.alloca i32
+!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
+!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
+!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
More information about the llvm-commits
mailing list