[flang-commits] [flang] d989ff9 - [flang][OpenMP] Add lowering of subroutine calls in custom reduction combiners (#169808)

via flang-commits flang-commits at lists.llvm.org
Fri Nov 28 06:00:22 PST 2025


Author: Jan Leyonberg
Date: 2025-11-28T09:00:18-05:00
New Revision: d989ff93e2a073cb921cfcfeb9728a0b51892f1a

URL: https://github.com/llvm/llvm-project/commit/d989ff93e2a073cb921cfcfeb9728a0b51892f1a
DIFF: https://github.com/llvm/llvm-project/commit/d989ff93e2a073cb921cfcfeb9728a0b51892f1a.diff

LOG: [flang][OpenMP] Add lowering of subroutine calls in custom reduction combiners (#169808)

This patch adds support for lowering subroutine calls in custom
reduction combiners to MLIR.

Added: 
    flang/test/Lower/OpenMP/omp-declare-reduction-combsub.f90

Modified: 
    flang/lib/Lower/OpenMP/OpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index abe65cdb2102f..0a200388a36e5 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -20,6 +20,7 @@
 #include "flang/Common/idioms.h"
 #include "flang/Evaluate/type.h"
 #include "flang/Lower/Bridge.h"
+#include "flang/Lower/ConvertCall.h"
 #include "flang/Lower/ConvertExpr.h"
 #include "flang/Lower/ConvertExprToHLFIR.h"
 #include "flang/Lower/ConvertVariable.h"
@@ -3582,19 +3583,32 @@ processReductionCombiner(lower::AbstractConverter &converter,
   const parser::OmpStylizedInstance::Instance &instance =
       std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
 
-  const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
-  if (!as) {
-    TODO(converter.getCurrentLocation(),
-         "A combiner that is a subroutine call is not yet supported");
+  std::optional<semantics::SomeExpr> evalExprOpt;
+  if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
+    auto &expr = std::get<parser::Expr>(as->t);
+    evalExprOpt = makeExpr(expr, semaCtx);
+  } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
+    if (call->typedCall) {
+      const auto &procRef = *call->typedCall;
+      evalExprOpt = semantics::SomeExpr{procRef};
+    } else {
+      TODO(converter.getCurrentLocation(),
+           "CallStmt without typedCall is not yet supported");
+    }
+  } else {
+    TODO(converter.getCurrentLocation(), "Unsupported combiner instance type");
   }
-  auto &expr = std::get<parser::Expr>(as->t);
-  genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
-                      mlir::Type type, mlir::Value lhs, mlir::Value rhs,
-                      bool isByRef) {
-    const auto &evalExpr = makeExpr(expr, semaCtx);
+
+  assert(evalExprOpt.has_value() && "evalExpr must be initialized");
+  semantics::SomeExpr evalExpr = *evalExprOpt;
+
+  genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc,
+                                mlir::Type type, mlir::Value lhs,
+                                mlir::Value rhs, bool isByRef) {
     lower::SymMapScope scope(symTable);
     const std::list<parser::OmpStylizedDeclaration> &declList =
         std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+    mlir::Value ompOutVar;
     for (const parser::OmpStylizedDeclaration &decl : declList) {
       auto &name = std::get<parser::ObjectName>(decl.var.t);
       mlir::Value addr = lhs;
@@ -3617,15 +3631,32 @@ processReductionCombiner(lower::AbstractConverter &converter,
       auto declareOp =
           hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
                                    {}, nullptr, nullptr, 0, attributes);
+      if (name.ToString() == "omp_out")
+        ompOutVar = declareOp.getResult(0);
       symTable.addVariableDefinition(*name.symbol, declareOp);
     }
 
     lower::StatementContext stmtCtx;
-    mlir::Value result = fir::getBase(
-        convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
-    if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
-      if (lhs.getType() == refType.getElementType())
-        result = fir::LoadOp::create(builder, loc, result);
+    mlir::Value result = common::visit(
+        common::visitors{
+            [&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
+              convertCallToHLFIR(loc, converter, procRef, std::nullopt,
+                                 symTable, stmtCtx);
+              auto outVal = fir::LoadOp::create(builder, loc, ompOutVar);
+              return outVal;
+            },
+            [&](const auto &expr) -> mlir::Value {
+              mlir::Value exprResult = fir::getBase(convertExprToValue(
+                  loc, converter, evalExpr, symTable, stmtCtx));
+              // Optional load may be generated if we get a reference to the
+              // reduction type.
+              if (auto refType =
+                      llvm::dyn_cast<fir::ReferenceType>(exprResult.getType()))
+                if (lhs.getType() == refType.getElementType())
+                  exprResult = fir::LoadOp::create(builder, loc, exprResult);
+              return exprResult;
+            }},
+        evalExpr.u);
     stmtCtx.finalizeAndPop();
     if (isByRef) {
       fir::StoreOp::create(builder, loc, result, lhs);

diff  --git a/flang/test/Lower/OpenMP/omp-declare-reduction-combsub.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-combsub.f90
new file mode 100644
index 0000000000000..098b3f84aa2f3
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-reduction-combsub.f90
@@ -0,0 +1,60 @@
+! This test checks lowering of the OpenMP declare reduction directive with a
+! combiner that is a subroutine call.
+
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+subroutine combine_me(out, in)  ! Used as the red_add combiner: accumulates in into out
+  integer out, in
+  out = out + in
+end subroutine combine_me
+
+function func(x, n)  ! Returns the sum of x(1:n), computed with the user-defined red_add reduction
+  integer func
+  integer x(n)
+  integer res  ! Reduction variable accumulated by the simd loop below
+  interface  ! Explicit interface for combine_me, which the red_add combiner calls
+     subroutine combine_me(out, in)
+       integer out, in
+     end subroutine combine_me
+  end interface
+!CHECK:  omp.declare_reduction @red_add : i32 init {
+!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
+!CHECK:    %[[OMP_PRIV:.*]] = fir.alloca i32
+!CHECK:    %[[OMP_ORIG:.*]] = fir.alloca i32
+!CHECK:    fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
+!CHECK:    %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
+!CHECK:    %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[CONST_0:.*]] = arith.constant 0 : i32
+!CHECK:    omp.yield(%[[CONST_0]] : i32)
+!CHECK:  } combiner {
+!CHECK:  ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
+!CHECK:    %[[OMP_OUT:.*]] = fir.alloca i32
+!CHECK:    %[[OMP_IN:.*]] = fir.alloca i32
+!CHECK:    fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
+!CHECK:    %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
+!CHECK:    %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    fir.call @_QPcombine_me(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
+!CHECK:    %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
+!CHECK:    omp.yield(%[[OMP_OUT_VAL]] : i32)
+!CHECK:  }
+!CHECK:  func.func @_QPcombine_me(%[[OUT:.*]]: !fir.ref<i32> {fir.bindc_name = "out"}, %[[IN:.*]]: !fir.ref<i32> {fir.bindc_name = "in"}) {
+!CHECK:    %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
+!CHECK:    %[[IN_DECL:.*]]:2 = hlfir.declare %[[IN]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFcombine_meEin"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[OUT_DECL:.*]]:2 = hlfir.declare %[[OUT]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QFcombine_meEout"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[OUT_VAL:.*]] = fir.load %[[OUT_DECL]]#0 : !fir.ref<i32>
+!CHECK:    %[[IN_VAL:.*]] = fir.load %[[IN_DECL]]#0 : !fir.ref<i32>
+!CHECK:    %[[SUM:.*]] = arith.addi %[[OUT_VAL]], %[[IN_VAL]] : i32
+!CHECK:    hlfir.assign %[[SUM]] to %[[OUT_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK:    return
+!CHECK:  }
+!$omp declare reduction(red_add:integer(4):combine_me(omp_out,omp_in)) initializer(omp_priv=0)
+  res=0  ! Matches the initializer value (omp_priv=0) of red_add
+!$omp simd reduction(red_add:res)
+  do i=1,n
+     res=res+x(i)
+  enddo
+  func=res
+end function func
+


        


More information about the flang-commits mailing list