[flang] [llvm] [mlir] [OpenMP][flang] Lowering of OpenMP custom reductions to MLIR (PR #168417)
Jan Leyonberg via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 18 11:56:49 PST 2025
================
@@ -3563,12 +3565,137 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}
+// Prepares the callback used to lower the combiner of a declare reduction
+// specifier. When the first combiner instance is an assignment statement,
+// `genCombinerCB` is set to a lambda that binds the stylized declarations to
+// the callback's `lhs`/`rhs` values, lowers the right-hand side expression of
+// the assignment, and emits the terminating omp.yield. For any other kind of
+// instance `genCombinerCB` is left untouched.
+//
+// Always returns true.
+// NOTE(review): the result does not tell the caller whether a callback was
+// actually installed -- confirm that callers do not rely on it for that.
+static bool
+processReductionCombiner(lower::AbstractConverter &converter,
+ lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx,
+ const parser::OmpReductionSpecifier &specifier,
+ ReductionProcessor::GenCombinerCBTy &genCombinerCB) {
+ // NOTE(review): .value() assumes the specifier always carries a combiner
+ // expression -- confirm that this is guaranteed before lowering.
+ const auto &combinerExpression =
+ std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+ .value();
+ // Only the first stylized instance of the combiner is considered.
+ const parser::OmpStylizedInstance &combinerInstance =
+ combinerExpression.v.front();
+ const parser::OmpStylizedInstance::Instance &instance =
+ std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
+ if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
+ // Right-hand side expression of the combiner assignment.
+ auto &expr = std::get<parser::Expr>(as->t);
+ // The lambda is stored in the out-parameter and is not invoked here. It
+ // captures by reference, so everything it refers to (converter, symTable,
+ // semaCtx and the parse tree nodes) has to be alive when it is called.
+ genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Type type, mlir::Value lhs, mlir::Value rhs,
+ bool isByRef) {
+ const auto &evalExpr = makeExpr(expr, semaCtx);
+ // Symbol bindings added below are made inside this scope object.
+ // NOTE(review): presumably they are dropped when `scope` is destroyed --
+ // confirm against lower::SymMapScope.
+ lower::SymMapScope scope(symTable);
+ const std::list<parser::OmpStylizedDeclaration> &declList =
+ std::get<std::list<parser::OmpStylizedDeclaration>>(
+ combinerInstance.t);
+ // Bind each stylized declaration to one of the callback operands: the
+ // name "omp_in" maps to `rhs`, every other name maps to `lhs`.
+ for (const parser::OmpStylizedDeclaration &decl : declList) {
+ auto &name = std::get<parser::ObjectName>(decl.var.t);
+ mlir::Value addr = lhs;
+ // This local `type` shadows the lambda parameter of the same name; the
+ // parameter itself is not read anywhere in the lambda.
+ mlir::Type type = lhs.getType();
+ bool isRhs = name.ToString() == std::string("omp_in");
+ if (isRhs) {
+ addr = rhs;
+ type = rhs.getType();
+ }
+
+ assert(name.symbol && "Reduction object name does not have a symbol");
+ // When fir::conformsWithPassByRef(type) is false, the operand is stored
+ // into a fresh temporary and the temporary is what gets declared.
+ if (!fir::conformsWithPassByRef(type)) {
+ addr = builder.createTemporary(loc, type);
+ fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
+ }
+ fir::FortranVariableFlagsEnum extraFlags = {};
+ fir::FortranVariableFlagsAttr attributes =
+ Fortran::lower::translateSymbolAttributes(builder.getContext(),
+ *name.symbol, extraFlags);
+ auto declareOp = hlfir::DeclareOp::create(
+ builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
+ 0, attributes);
+ // Register the declaration for the symbol before the expression is
+ // lowered below using the same symTable.
+ symTable.addVariableDefinition(*name.symbol, declareOp);
+ }
+
+ lower::StatementContext stmtCtx;
+ mlir::Value result = fir::getBase(
+ convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
+ // If lowering produced a reference whose element type equals the type of
+ // `lhs`, load through it so that `result` has the same type as `lhs`.
+ if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
+ if (lhs.getType() == refType.getElementType())
+ result = fir::LoadOp::create(builder, loc, result);
+ stmtCtx.finalizeAndPop();
+ // By-reference: store the combined value into `lhs` and yield `lhs`.
+ // Otherwise: yield the combined value directly.
+ if (isByRef) {
+ fir::StoreOp::create(builder, loc, result, lhs);
+ mlir::omp::YieldOp::create(builder, loc, lhs);
+ } else {
+ mlir::omp::YieldOp::create(builder, loc, result);
+ }
+
+ return result;
+ };
+ }
+ return true;
+}
+
+// Returns the MLIR type of the reduction described by `specifier`.
+//
+// The type is taken from the symbol attached to the first stylized declaration
+// of the first combiner instance. Going through the symbol is simpler than
+// going through a DeclSpec because derived and intrinsic types do not have to
+// be told apart; semantics is guaranteed to generate these symbols.
+static mlir::Type
+getReductionType(lower::AbstractConverter &converter,
+                 const parser::OmpReductionSpecifier &specifier) {
+  // First stylized instance of the combiner expression.
+  const parser::OmpStylizedInstance &firstInstance =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value()
+          .v.front();
+  // Object name of the first declaration attached to that instance.
+  const auto &objectName = std::get<parser::ObjectName>(
+      std::get<std::list<parser::OmpStylizedDeclaration>>(firstInstance.t)
+          .front()
+          .var.t);
+  return converter.genType(semantics::SymbolRef(*objectName.symbol));
+}
+
static void genOMP(
lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
- if (!semaCtx.langOptions().OpenMPSimd)
- TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
+ if (!semaCtx.langOptions().OpenMPSimd) {
+ const parser::OmpArgumentList &args{
+ declareReductionConstruct.v.Arguments()};
+ const parser::OmpArgument &arg{args.v.front()};
+ const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
+
+ if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
+ TODO(converter.getCurrentLocation(),
+ "multiple types in declare target is not yet supported");
----------------
jsjodin wrote:
Thanks for finding those.
https://github.com/llvm/llvm-project/pull/168417
More information about the llvm-commits
mailing list