[flang] [llvm] [mlir] [OpenMP][flang] Lowering of OpenMP custom reductions to MLIR (PR #168417)

Jan Leyonberg via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 18 11:56:49 PST 2025


================
@@ -3563,12 +3565,137 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
     TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
 }
 
+/// Build the combiner callback for a user-defined (declare) reduction.
+///
+/// Only the first stylized instance of the combiner expression is used. When
+/// that instance is an assignment statement, \p genCombinerCB is set to a
+/// callback that, each time it is invoked:
+///   - binds the stylized variable `omp_in` to \p rhs and every other declared
+///     stylized variable to \p lhs (spilling by-value operands to a
+///     temporary first),
+///   - lowers the right-hand side of the assignment,
+///   - emits the terminating omp.yield (storing into \p lhs first when the
+///     reduction is by reference), and returns the lowered value.
+/// Any other combiner form is not lowered yet and is reported with TODO.
+///
+/// \returns true once \p genCombinerCB has been set.
+static bool
+processReductionCombiner(lower::AbstractConverter &converter,
+                         lower::SymMap &symTable,
+                         semantics::SemanticsContext &semaCtx,
+                         const parser::OmpReductionSpecifier &specifier,
+                         ReductionProcessor::GenCombinerCBTy &genCombinerCB) {
+  const auto &combinerExpression =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value();
+  const parser::OmpStylizedInstance &combinerInstance =
+      combinerExpression.v.front();
+  const parser::OmpStylizedInstance::Instance &instance =
+      std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
+  const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
+  // Leaving genCombinerCB empty would only fail later, when the callback is
+  // invoked; report the unsupported combiner form here instead.
+  if (!as)
+    TODO(converter.getCurrentLocation(),
+         "declare reduction combiner that is not an assignment statement");
+  const parser::Expr &expr = std::get<parser::Expr>(as->t);
+
+  // The callback is stored and run later, so capture explicitly; all captured
+  // objects are owned by the caller / parse tree, not by this frame.
+  genCombinerCB = [&converter, &symTable, &semaCtx, &expr, &combinerInstance](
+                      fir::FirOpBuilder &builder, mlir::Location loc,
+                      mlir::Type /*type*/, mlir::Value lhs, mlir::Value rhs,
+                      bool isByRef) {
+    const auto &evalExpr = makeExpr(expr, semaCtx);
+    // Keep the stylized variable bindings local to this combiner.
+    lower::SymMapScope scope(symTable);
+    const auto &declList =
+        std::get<std::list<parser::OmpStylizedDeclaration>>(
+            combinerInstance.t);
+    for (const parser::OmpStylizedDeclaration &decl : declList) {
+      const auto &name = std::get<parser::ObjectName>(decl.var.t);
+      assert(name.symbol && "Reduction object name does not have a symbol");
+      // omp_in is bound to the incoming value; every other declared name
+      // (presumably omp_out) is bound to the accumulator.
+      const bool isRhs = name.ToString() == "omp_in";
+      const mlir::Value value = isRhs ? rhs : lhs;
+      const mlir::Type varType = value.getType();
+      mlir::Value addr = value;
+      // Values that cannot be passed by reference are spilled to a temporary
+      // so the variable can be declared on an address.
+      if (!fir::conformsWithPassByRef(varType)) {
+        addr = builder.createTemporary(loc, varType);
+        fir::StoreOp::create(builder, loc, value, addr);
+      }
+      fir::FortranVariableFlagsEnum extraFlags = {};
+      fir::FortranVariableFlagsAttr attributes =
+          Fortran::lower::translateSymbolAttributes(builder.getContext(),
+                                                    *name.symbol, extraFlags);
+      auto declareOp = hlfir::DeclareOp::create(
+          builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
+          0, attributes);
+      symTable.addVariableDefinition(*name.symbol, declareOp);
+    }
+
+    lower::StatementContext stmtCtx;
+    mlir::Value result = fir::getBase(
+        convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
+    // If the expression lowered to a reference to lhs's type, load it so the
+    // yielded/stored value has the same type as lhs.
+    if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
+      if (lhs.getType() == refType.getElementType())
+        result = fir::LoadOp::create(builder, loc, result);
+    stmtCtx.finalizeAndPop();
+    if (isByRef) {
+      // By-reference reductions update the accumulator in place and yield it.
+      fir::StoreOp::create(builder, loc, result, lhs);
+      mlir::omp::YieldOp::create(builder, loc, lhs);
+    } else {
+      mlir::omp::YieldOp::create(builder, loc, result);
+    }
+
+    return result;
+  };
+  return true;
+}
+
+// Getting the type from a symbol compared to a DeclSpec is simpler since we do
+// not need to consider derived vs intrinsic types. Semantics is guaranteed to
+// generate these symbols.
+//
+// Returns the MLIR type of the first variable declared by the first stylized
+// instance of the combiner expression. This assumes all stylized variables of
+// the reduction share one type -- NOTE(review): confirm against semantics.
+static mlir::Type
+getReductionType(lower::AbstractConverter &converter,
+                 const parser::OmpReductionSpecifier &specifier) {
+  const auto &combinerExpression =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value();
+  assert(!combinerExpression.v.empty() &&
+         "Combiner expression has no stylized instance");
+  const parser::OmpStylizedInstance &combinerInstance =
+      combinerExpression.v.front();
+  const auto &declList =
+      std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+  assert(!declList.empty() && "Combiner has no stylized declarations");
+  const auto &name = std::get<parser::ObjectName>(declList.front().var.t);
+  // Same invariant processReductionCombiner relies on.
+  assert(name.symbol && "Reduction object name does not have a symbol");
+  return converter.genType(semantics::SymbolRef(*name.symbol));
+}
+
 static void genOMP(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
-  if (!semaCtx.langOptions().OpenMPSimd)
-    TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
+  if (!semaCtx.langOptions().OpenMPSimd) {
+    const parser::OmpArgumentList &args{
+        declareReductionConstruct.v.Arguments()};
+    const parser::OmpArgument &arg{args.v.front()};
+    const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
+
+    if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
+      TODO(converter.getCurrentLocation(),
+           "multiple types in declare target is not yet supported");
----------------
jsjodin wrote:

Thanks for finding those.

https://github.com/llvm/llvm-project/pull/168417


More information about the llvm-commits mailing list