[flang-commits] [flang] d989ff9 - [flang][OpenMP] Add lowering of subroutine calls in custom reduction combiners (#169808)
via flang-commits
flang-commits at lists.llvm.org
Fri Nov 28 06:00:22 PST 2025
Author: Jan Leyonberg
Date: 2025-11-28T09:00:18-05:00
New Revision: d989ff93e2a073cb921cfcfeb9728a0b51892f1a
URL: https://github.com/llvm/llvm-project/commit/d989ff93e2a073cb921cfcfeb9728a0b51892f1a
DIFF: https://github.com/llvm/llvm-project/commit/d989ff93e2a073cb921cfcfeb9728a0b51892f1a.diff
LOG: [flang][OpenMP] Add lowering of subroutine calls in custom reduction combiners (#169808)
This patch adds support for lowering subroutine calls in custom
reduction combiners to MLIR.
Added:
flang/test/Lower/OpenMP/omp-declare-reduction-combsub.f90
Modified:
flang/lib/Lower/OpenMP/OpenMP.cpp
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index abe65cdb2102f..0a200388a36e5 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -20,6 +20,7 @@
#include "flang/Common/idioms.h"
#include "flang/Evaluate/type.h"
#include "flang/Lower/Bridge.h"
+#include "flang/Lower/ConvertCall.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/ConvertVariable.h"
@@ -3582,19 +3583,32 @@ processReductionCombiner(lower::AbstractConverter &converter,
const parser::OmpStylizedInstance::Instance &instance =
std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
- const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
- if (!as) {
- TODO(converter.getCurrentLocation(),
- "A combiner that is a subroutine call is not yet supported");
+ std::optional<semantics::SomeExpr> evalExprOpt;
+ if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
+ auto &expr = std::get<parser::Expr>(as->t);
+ evalExprOpt = makeExpr(expr, semaCtx);
+ } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
+ if (call->typedCall) {
+ const auto &procRef = *call->typedCall;
+ evalExprOpt = semantics::SomeExpr{procRef};
+ } else {
+ TODO(converter.getCurrentLocation(),
+ "CallStmt without typedCall is not yet supported");
+ }
+ } else {
+ TODO(converter.getCurrentLocation(), "Unsupported combiner instance type");
}
- auto &expr = std::get<parser::Expr>(as->t);
- genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Type type, mlir::Value lhs, mlir::Value rhs,
- bool isByRef) {
- const auto &evalExpr = makeExpr(expr, semaCtx);
+
+ assert(evalExprOpt.has_value() && "evalExpr must be initialized");
+ semantics::SomeExpr evalExpr = *evalExprOpt;
+
+ genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Type type, mlir::Value lhs,
+ mlir::Value rhs, bool isByRef) {
lower::SymMapScope scope(symTable);
const std::list<parser::OmpStylizedDeclaration> &declList =
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+ mlir::Value ompOutVar;
for (const parser::OmpStylizedDeclaration &decl : declList) {
auto &name = std::get<parser::ObjectName>(decl.var.t);
mlir::Value addr = lhs;
@@ -3617,15 +3631,32 @@ processReductionCombiner(lower::AbstractConverter &converter,
auto declareOp =
hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
{}, nullptr, nullptr, 0, attributes);
+ if (name.ToString() == "omp_out")
+ ompOutVar = declareOp.getResult(0);
symTable.addVariableDefinition(*name.symbol, declareOp);
}
lower::StatementContext stmtCtx;
- mlir::Value result = fir::getBase(
- convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
- if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
- if (lhs.getType() == refType.getElementType())
- result = fir::LoadOp::create(builder, loc, result);
+ mlir::Value result = common::visit(
+ common::visitors{
+ [&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
+ convertCallToHLFIR(loc, converter, procRef, std::nullopt,
+ symTable, stmtCtx);
+ auto outVal = fir::LoadOp::create(builder, loc, ompOutVar);
+ return outVal;
+ },
+ [&](const auto &expr) -> mlir::Value {
+ mlir::Value exprResult = fir::getBase(convertExprToValue(
+ loc, converter, evalExpr, symTable, stmtCtx));
+ // Optional load may be generated if we get a reference to the
+ // reduction type.
+ if (auto refType =
+ llvm::dyn_cast<fir::ReferenceType>(exprResult.getType()))
+ if (lhs.getType() == refType.getElementType())
+ exprResult = fir::LoadOp::create(builder, loc, exprResult);
+ return exprResult;
+ }},
+ evalExpr.u);
stmtCtx.finalizeAndPop();
if (isByRef) {
fir::StoreOp::create(builder, loc, result, lhs);
diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-combsub.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-combsub.f90
new file mode 100644
index 0000000000000..098b3f84aa2f3
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-combsub.f90
@@ -0,0 +1,60 @@
+! This test checks lowering of the OpenMP declare reduction directive with a
+! combiner that is a subroutine call.
+
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+subroutine combine_me(out, in) ! Combiner for red_add: accumulates in into out.
+ integer out, in
+ out = out + in
+end subroutine combine_me
+
+function func(x, n) ! Sums x(1:n) into res using the user-defined reduction red_add.
+ integer func
+ integer x(n)
+ integer res
+ interface ! Explicit interface for the external combiner subroutine.
+ subroutine combine_me(out, in)
+ integer out, in
+ end subroutine combine_me
+ end interface
+!CHECK: omp.declare_reduction @red_add : i32 init {
+!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
+!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
+!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
+!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
+!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
+!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
+!CHECK: omp.yield(%[[CONST_0]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
+!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
+!CHECK: %[[OMP_IN:.*]] = fir.alloca i32
+!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
+!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
+!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: fir.call @_QPcombine_me(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
+!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
+!CHECK: omp.yield(%[[OMP_OUT_VAL]] : i32)
+!CHECK: }
+!CHECK: func.func @_QPcombine_me(%[[OUT:.*]]: !fir.ref<i32> {fir.bindc_name = "out"}, %[[IN:.*]]: !fir.ref<i32> {fir.bindc_name = "in"}) {
+!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
+!CHECK: %[[IN_DECL:.*]]:2 = hlfir.declare %[[IN]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFcombine_meEin"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[OUT_DECL:.*]]:2 = hlfir.declare %[[OUT]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QFcombine_meEout"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[OUT_VAL:.*]] = fir.load %[[OUT_DECL]]#0 : !fir.ref<i32>
+!CHECK: %[[IN_VAL:.*]] = fir.load %[[IN_DECL]]#0 : !fir.ref<i32>
+!CHECK: %[[SUM:.*]] = arith.addi %[[OUT_VAL]], %[[IN_VAL]] : i32
+!CHECK: hlfir.assign %[[SUM]] to %[[OUT_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK: return
+!CHECK: }
+!$omp declare reduction(red_add:integer(4):combine_me(omp_out,omp_in)) initializer(omp_priv=0)
+ res=0 ! Same value as the red_add initializer (omp_priv=0).
+!$omp simd reduction(red_add:res)
+ do i=1,n
+ res=res+x(i)
+ enddo
+ func=res
+end function func
+
More information about the flang-commits
mailing list