[flang] [llvm] [flang][OpenMP] Implement COMBINER clause (PR #172036)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 12 08:15:29 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Krzysztof Parzyszek (kparzysz)
<details>
<summary>Changes</summary>
This adds parsing and lowering of the COMBINER clause. It reuses the existing lowering code for the combiner-expression to lower the COMBINER clause as well.
---
Patch is 29.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172036.diff
13 Files Affected:
- (modified) flang/include/flang/Lower/OpenMP/Clauses.h (+5-1)
- (modified) flang/include/flang/Parser/dump-parse-tree.h (+1)
- (modified) flang/include/flang/Parser/openmp-utils.h (+3-3)
- (modified) flang/include/flang/Parser/parse-tree.h (+8)
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+6-4)
- (modified) flang/lib/Lower/OpenMP/Clauses.cpp (+31-18)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+67-73)
- (modified) flang/lib/Parser/openmp-parsers.cpp (+16-9)
- (modified) flang/lib/Parser/openmp-utils.cpp (+11-5)
- (modified) flang/lib/Semantics/check-omp-structure.cpp (+1)
- (added) flang/test/Parser/OpenMP/declare-reduction-combiner.f90 (+88)
- (modified) llvm/include/llvm/Frontend/OpenMP/ClauseT.h (+20-11)
- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+4)
``````````diff
diff --git a/flang/include/flang/Lower/OpenMP/Clauses.h b/flang/include/flang/Lower/OpenMP/Clauses.h
index 455eda2738e6d..5f03877624be7 100644
--- a/flang/include/flang/Lower/OpenMP/Clauses.h
+++ b/flang/include/flang/Lower/OpenMP/Clauses.h
@@ -104,6 +104,7 @@ struct hash<Fortran::lower::omp::IdTy> {
namespace Fortran::lower::omp {
using Object = tomp::ObjectT<IdTy, ExprTy>;
using ObjectList = tomp::ObjectListT<IdTy, ExprTy>;
+using StylizedInstance = tomp::type::StylizedInstanceT<IdTy, ExprTy>;
Object makeObject(const parser::OmpObject &object,
semantics::SemanticsContext &semaCtx);
@@ -173,8 +174,10 @@ std::optional<ResultTy> maybeApplyToV(FuncTy &&func, const ArgTy *arg) {
std::optional<Object> getBaseObject(const Object &object,
semantics::SemanticsContext &semaCtx);
+StylizedInstance makeStylizedInstance(const parser::OmpStylizedInstance &inp,
+ semantics::SemanticsContext &semaCtx);
+
namespace clause {
-using StylizedInstance = tomp::type::StylizedInstanceT<IdTy, ExprTy>;
using Range = tomp::type::RangeT<ExprTy>;
using Mapper = tomp::type::MapperT<IdTy, ExprTy>;
using Iterator = tomp::type::IteratorT<TypeTy, IdTy, ExprTy>;
@@ -208,6 +211,7 @@ using Bind = tomp::clause::BindT<TypeTy, IdTy, ExprTy>;
using Capture = tomp::clause::CaptureT<TypeTy, IdTy, ExprTy>;
using Collapse = tomp::clause::CollapseT<TypeTy, IdTy, ExprTy>;
using Collector = tomp::clause::CollectorT<TypeTy, IdTy, ExprTy>;
+using Combiner = tomp::clause::CombinerT<TypeTy, IdTy, ExprTy>;
using Compare = tomp::clause::CompareT<TypeTy, IdTy, ExprTy>;
using Contains = tomp::clause::ContainsT<TypeTy, IdTy, ExprTy>;
using Copyin = tomp::clause::CopyinT<TypeTy, IdTy, ExprTy>;
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 252e156d2d459..dbc2b6541dd75 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -562,6 +562,7 @@ class ParseTreeDumper {
NODE(parser, OmpClauseList)
NODE(parser, OmpCloseModifier)
NODE_ENUM(OmpCloseModifier, Value)
+ NODE(parser, OmpCombinerClause)
NODE(parser, OmpCombinerExpression)
NODE(parser, OmpContainsClause)
NODE(parser, OmpContextSelectorSpecification)
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 0fc7dbd29d6aa..bd200558e4c59 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -226,9 +226,9 @@ const BlockConstruct *GetFortranBlockConstruct(
const Block &GetInnermostExecPart(const Block &block);
bool IsStrictlyStructuredBlock(const Block &block);
-const OmpCombinerExpression *GetCombinerExpr(
- const OmpReductionSpecifier &rspec);
-const OmpInitializerExpression *GetInitializerExpr(const OmpClause &init);
+const OmpCombinerExpression *GetCombinerExpr(const OmpReductionSpecifier &x);
+const OmpCombinerExpression *GetCombinerExpr(const OmpClause &x);
+const OmpInitializerExpression *GetInitializerExpr(const OmpClause &x);
struct OmpAllocateInfo {
std::vector<const OmpAllocateDirective *> dirs;
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 93743709f10d2..b00d25373f801 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4395,6 +4395,14 @@ struct OmpCancellationConstructTypeClause {
std::tuple<OmpDirectiveName, std::optional<ScalarLogicalExpr>> t;
};
+// Ref: [6.0:262]
+//
+// combiner-clause -> // since 6.0
+// COMBINER(combiner-expr)
+struct OmpCombinerClause {
+ WRAPPER_CLASS_BOILERPLATE(OmpCombinerClause, OmpCombinerExpression);
+};
+
// Ref: [5.2:214]
//
// contains-clause ->
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 3c31b3a07f57f..b923e415231d6 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -390,10 +390,10 @@ bool ClauseProcessor::processInitializer(
mlir::Type type, mlir::Value ompOrig) {
lower::SymMapScope scope(symMap);
mlir::Value ompPrivVar;
- const clause::StylizedInstance &inst = clause->v.front();
+ const StylizedInstance &inst = clause->v.front();
for (const Object &object :
- std::get<clause::StylizedInstance::Variables>(inst.t)) {
+ std::get<StylizedInstance::Variables>(inst.t)) {
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
fir::StoreOp::create(builder, loc, ompOrig, addr);
fir::FortranVariableFlagsEnum extraFlags = {};
@@ -412,7 +412,7 @@ bool ClauseProcessor::processInitializer(
// Lower the expression/function call
lower::StatementContext stmtCtx;
const semantics::SomeExpr &initExpr =
- std::get<clause::StylizedInstance::Instance>(inst.t);
+ std::get<StylizedInstance::Instance>(inst.t);
mlir::Value result = common::visit(
common::visitors{
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
@@ -439,7 +439,9 @@ bool ClauseProcessor::processInitializer(
};
return true;
}
- return false;
+ TODO(converter.getCurrentLocation(),
+ "declare reduction without an initializer clause is not yet "
+ "supported");
}
bool ClauseProcessor::processMergeable(
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 9ea4e8fcd6c0e..d53054f005dea 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -197,6 +197,24 @@ std::optional<Object> getBaseObject(const Object &object,
return std::nullopt;
}
+StylizedInstance makeStylizedInstance(const parser::OmpStylizedInstance &inp,
+ semantics::SemanticsContext &semaCtx) {
+ ObjectList variables;
+ llvm::transform(std::get<std::list<parser::OmpStylizedDeclaration>>(inp.t),
+ std::back_inserter(variables),
+ [&](const parser::OmpStylizedDeclaration &s) {
+ return makeObject(s.var, semaCtx);
+ });
+
+ SomeExpr instance = [&]() {
+ if (auto &&expr = semantics::omp::MakeEvaluateExpr(inp))
+ return std::move(*expr);
+ llvm_unreachable("Expecting expression instance");
+ }();
+
+ return StylizedInstance{{std::move(variables), std::move(instance)}};
+}
+
// Helper macros
#define MAKE_EMPTY_CLASS(cls, from_cls) \
cls make(const parser::OmpClause::from_cls &, \
@@ -551,6 +569,17 @@ Collapse make(const parser::OmpClause::Collapse &inp,
return Collapse{/*N=*/makeExpr(inp.v, semaCtx)};
}
+Combiner make(const parser::OmpClause::Combiner &inp,
+ semantics::SemanticsContext &semaCtx) {
+ const parser::OmpCombinerExpression &cexpr = inp.v.v;
+ Combiner combiner;
+
+ for (const parser::OmpStylizedInstance &sinst : cexpr.v)
+ combiner.v.push_back(makeStylizedInstance(sinst, semaCtx));
+
+ return combiner;
+}
+
// Compare: empty
Contains make(const parser::OmpClause::Contains &inp,
@@ -988,24 +1017,8 @@ Initializer make(const parser::OmpClause::Initializer &inp,
const parser::OmpInitializerExpression &iexpr = inp.v.v;
Initializer initializer;
- for (const parser::OmpStylizedInstance &sinst : iexpr.v) {
- ObjectList variables;
- llvm::transform(
- std::get<std::list<parser::OmpStylizedDeclaration>>(sinst.t),
- std::back_inserter(variables),
- [&](const parser::OmpStylizedDeclaration &s) {
- return makeObject(s.var, semaCtx);
- });
-
- SomeExpr instance = [&]() {
- if (auto &&expr = semantics::omp::MakeEvaluateExpr(sinst))
- return std::move(*expr);
- llvm_unreachable("Expecting expression instance");
- }();
-
- initializer.v.push_back(
- StylizedInstance{{std::move(variables), std::move(instance)}});
- }
+ for (const parser::OmpStylizedInstance &sinst : iexpr.v)
+ initializer.v.push_back(makeStylizedInstance(sinst, semaCtx));
return initializer;
}
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 38ab42076f559..7965119764e5d 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3602,57 +3602,28 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}
-static ReductionProcessor::GenCombinerCBTy
-processReductionCombiner(lower::AbstractConverter &converter,
- lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- const parser::OmpReductionSpecifier &specifier) {
+static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
+ lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, const clause::Combiner &combiner) {
ReductionProcessor::GenCombinerCBTy genCombinerCB;
- const auto &combinerExpression =
- std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
- .value();
- const parser::OmpStylizedInstance &combinerInstance =
- combinerExpression.v.front();
- const parser::OmpStylizedInstance::Instance &instance =
- std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
-
- std::optional<semantics::SomeExpr> evalExprOpt;
- if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
- auto &expr = std::get<parser::Expr>(as->t);
- evalExprOpt = makeExpr(expr, semaCtx);
- } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
- if (call->typedCall) {
- const auto &procRef = *call->typedCall;
- evalExprOpt = semantics::SomeExpr{procRef};
- } else {
- TODO(converter.getCurrentLocation(),
- "CallStmt without typedCall is not yet supported");
- }
- } else {
- TODO(converter.getCurrentLocation(), "Unsupported combiner instance type");
- }
-
- assert(evalExprOpt.has_value() && "evalExpr must be initialized");
- semantics::SomeExpr evalExpr = *evalExprOpt;
+ const StylizedInstance &inst = combiner.v.front();
+ semantics::SomeExpr evalExpr = std::get<StylizedInstance::Instance>(inst.t);
genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value lhs,
mlir::Value rhs, bool isByRef) {
lower::SymMapScope scope(symTable);
- const std::list<parser::OmpStylizedDeclaration> &declList =
- std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
mlir::Value ompOutVar;
- for (const parser::OmpStylizedDeclaration &decl : declList) {
- auto &name = std::get<parser::ObjectName>(decl.var.t);
+ for (const Object &object : std::get<StylizedInstance::Variables>(inst.t)) {
mlir::Value addr = lhs;
mlir::Type type = lhs.getType();
- bool isRhs = name.ToString() == std::string("omp_in");
+ std::string name = object.sym()->name().ToString();
+ bool isRhs = name == "omp_in";
if (isRhs) {
addr = rhs;
type = rhs.getType();
}
- assert(name.symbol && "Reduction object name does not have a symbol");
if (!fir::conformsWithPassByRef(type)) {
addr = builder.createTemporary(loc, type);
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
@@ -3660,13 +3631,13 @@ processReductionCombiner(lower::AbstractConverter &converter,
fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(builder.getContext(),
- *name.symbol, extraFlags);
+ *object.sym(), extraFlags);
auto declareOp =
- hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
- {}, nullptr, nullptr, 0, attributes);
- if (name.ToString() == "omp_out")
+ hlfir::DeclareOp::create(builder, loc, addr, name, nullptr, {},
+ nullptr, nullptr, 0, attributes);
+ if (name == "omp_out")
ompOutVar = declareOp.getResult(0);
- symTable.addVariableDefinition(*name.symbol, declareOp);
+ symTable.addVariableDefinition(*object.sym(), declareOp);
}
lower::StatementContext stmtCtx;
@@ -3740,46 +3711,69 @@ getReductionType(lower::AbstractConverter &converter,
return reductionType;
}
-static void genOMP(
- lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
- const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
+// Represent the reduction combiner as a clause, return reference to it.
+// If there is a "combiner" clause already present, do nothing. Otherwise
+// manufacture a combiner clause from the combiner expression on the reduction
+// specifier and append it to the list of clauses.
+static const clause::Combiner &
+appendCombiner(const parser::OpenMPDeclareReductionConstruct &construct,
+ List<Clause> &clauses, semantics::SemanticsContext &semaCtx) {
+ for (const Clause &clause : clauses) {
+ if (clause.id == llvm::omp::Clause::OMPC_combiner)
+ return std::get<clause::Combiner>(clause.u);
+ }
+
+ using namespace parser::omp;
+ const parser::OmpDirectiveSpecification &dirSpec = construct.v;
+ auto *specifier = GetFirstArgument<parser::OmpReductionSpecifier>(dirSpec);
+ assert(specifier && "Expecting reduction specifier");
+ if (auto *expr = GetCombinerExpr(*specifier)) {
+ clause::Combiner combiner;
+ for (const parser::OmpStylizedInstance &sinst : expr->v)
+ combiner.v.push_back(makeStylizedInstance(sinst, semaCtx));
+ clauses.push_back(makeClause(llvm::omp::Clause::OMPC_combiner,
+ std::move(combiner), expr->source));
+ return std::get<clause::Combiner>(clauses.back().u);
+ }
+
+ llvm_unreachable("Expecting reduction combiner");
+}
+
+static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval,
+ const parser::OpenMPDeclareReductionConstruct &construct) {
if (semaCtx.langOptions().OpenMPSimd)
return;
- const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
- const parser::OmpArgument &arg{args.v.front()};
- const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
-
+ const auto &specifier =
+ DEREF(parser::omp::GetFirstArgument<parser::OmpReductionSpecifier>(
+ construct.v));
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
TODO(converter.getCurrentLocation(),
"multiple types in declare reduction is not yet supported");
mlir::Type reductionType = getReductionType(converter, specifier);
+ List<Clause> clauses = makeClauses(construct.v.Clauses(), semaCtx);
+ const clause::Combiner &combiner =
+ appendCombiner(construct, clauses, semaCtx);
+
ReductionProcessor::GenCombinerCBTy genCombinerCB =
- processReductionCombiner(converter, symTable, semaCtx, specifier);
- const parser::OmpClauseList &initializer =
- declareReductionConstruct.v.Clauses();
- if (initializer.v.size() > 0) {
- List<Clause> clauses = makeClauses(initializer, semaCtx);
- ReductionProcessor::GenInitValueCBTy genInitValueCB;
- ClauseProcessor cp(converter, semaCtx, clauses);
- cp.processInitializer(symTable, genInitValueCB);
- const auto &identifier =
- std::get<parser::OmpReductionIdentifier>(specifier.t);
- const auto &designator =
- std::get<parser::ProcedureDesignator>(identifier.u);
- const auto &reductionName = std::get<parser::Name>(designator.u);
- bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
- ReductionProcessor::createDeclareReductionHelper<
- mlir::omp::DeclareReductionOp>(
- converter, reductionName.ToString(), reductionType,
- converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
- } else {
- TODO(converter.getCurrentLocation(),
- "declare reduction without an initializer clause is not yet "
- "supported");
- }
+ processReductionCombiner(converter, symTable, semaCtx, combiner);
+
+ ReductionProcessor::GenInitValueCBTy genInitValueCB;
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processInitializer(symTable, genInitValueCB);
+
+ const auto &identifier =
+ std::get<parser::OmpReductionIdentifier>(specifier.t);
+ const auto &designator = std::get<parser::ProcedureDesignator>(identifier.u);
+ const auto &reductionName = std::get<parser::Name>(designator.u);
+ bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+ ReductionProcessor::createDeclareReductionHelper<
+ mlir::omp::DeclareReductionOp>(
+ converter, reductionName.ToString(), reductionType,
+ converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
}
static void
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 24bdef9f88ed4..1f0d1be9adc00 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -488,27 +488,30 @@ static void InstantiateDeclareReduction(OmpDirectiveSpecification &spec) {
return;
}
- const OmpTypeNameList *typeNames{nullptr};
+ const OmpTypeNameList &typeNames{std::get<OmpTypeNameList>(rspec->t)};
if (auto *cexpr{
const_cast<OmpCombinerExpression *>(GetCombinerExpr(*rspec))}) {
- typeNames = &std::get<OmpTypeNameList>(rspec->t);
-
- InstantiateForTypes(*cexpr, *typeNames, OmpCombinerExpression::Variables());
+ InstantiateForTypes(*cexpr, typeNames, OmpCombinerExpression::Variables());
delete cexpr->state;
cexpr->state = nullptr;
- } else {
- // If there are no types, there is nothing else to do.
- return;
}
for (const OmpClause &clause : spec.Clauses().v) {
llvm::omp::Clause id{clause.Id()};
- if (id == llvm::omp::Clause::OMPC_initializer) {
+ if (id == llvm::omp::Clause::OMPC_combiner) {
+ if (auto *cexpr{
+ const_cast<OmpCombinerExpression *>(GetCombinerExpr(clause))}) {
+ InstantiateForTypes(
+ *cexpr, typeNames, OmpCombinerExpression::Variables());
+ delete cexpr->state;
+ cexpr->state = nullptr;
+ }
+ } else if (id == llvm::omp::Clause::OMPC_initializer) {
if (auto *iexpr{const_cast<OmpInitializerExpression *>(
GetInitializerExpr(clause))}) {
InstantiateForTypes(
- *iexpr, *typeNames, OmpInitializerExpression::Variables());
+ *iexpr, typeNames, OmpInitializerExpression::Variables());
delete iexpr->state;
iexpr->state = nullptr;
}
@@ -1316,6 +1319,8 @@ TYPE_PARSER(construct<OmpDetachClause>(Parser<OmpObject>{}))
TYPE_PARSER(construct<OmpHintClause>(scalarIntConstantExpr))
+TYPE_PARSER(construct<OmpCombinerClause>(Parser<OmpCombinerExpression>{}))
+
// init clause
TYPE_PARSER(construct<OmpInitClause>(
maybe(nonemptyList(Parser<OmpInitClause::Modifier>{}) / ":"),
@@ -1426,6 +1431,8 @@ TYPE_PARSER( //
"CAPTURE" >> construct<OmpClause>(construct<OmpClause::Capture>()) ||
"COLLAPSE" >> construct<OmpClause>(construct<OmpClause::Collapse>(
parenthesized(scalarIntConstantExpr))) ||
+ "COMBINER" >> construct<OmpClause>(construct<OmpClause::Combiner>(
+ parenthesized(Parser<OmpCombinerClause>{}))) ||
"COMPARE" >> construct<OmpClause>(construct<OmpClause::Compare>()) ||
"CONTAINS" >> construct<OmpClause>(construct<OmpClause::Contains>(
parenthesized(Parser<OmpContainsClause>{}))) ||
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index f96a5fca778e1..a9dbb55819b1e 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/open...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/172036
More information about the llvm-commits mailing list