[flang-commits] [flang] [Flang][OpenMP] Fix crash and IR errors for user-defined reduction on allocatable variables (PR #186765)

via flang-commits flang-commits at lists.llvm.org
Mon Mar 16 03:21:11 PDT 2026


https://github.com/Ritanya-B-Bharadwaj updated https://github.com/llvm/llvm-project/pull/186765

>From a00cc6584a801f2b969ee35d707f8ce436ca19a6 Mon Sep 17 00:00:00 2001
From: Ritanya B Bharadwaj <ritanya.b.bharadwaj at gmail.com>
Date: Mon, 16 Mar 2026 05:07:05 -0500
Subject: [PATCH 1/2] fixing flang issue #186743

---
 .../lib/Lower/Support/ReductionProcessor.cpp  | 181 ++++++++++++++++++
 .../OpenMP/declare-reduction-allocatable.f90  | 117 +++++++++++
 2 files changed, 298 insertions(+)
 create mode 100644 flang/test/Lower/OpenMP/declare-reduction-allocatable.f90

diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index e0cba4c512258..078d8f24d5f7a 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -23,6 +23,7 @@
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/IRMapping.h"
 #include "llvm/Support/CommandLine.h"
 #include <type_traits>
 
@@ -803,6 +804,185 @@ bool ReductionProcessor::processReductionArguments(
                          &redOperator.u)) {
         if (!ReductionProcessor::supportedIntrinsicProcReduction(
                 *reductionIntrinsic)) {
+          if (isByRef) {
+            // User-defined reduction on an allocatable/pointer variable:
+            // we need a new declare_reduction for the boxed type, reusing
+            // the init value and combiner from the existing one.
+            semantics::Symbol *sym = reductionIntrinsic->v.sym();
+            std::string baseName = sym->name().ToString();
+            mlir::ModuleOp module = builder.getModule();
+            auto existingDecl = module.lookupSymbol<OpType>(baseName);
+            if (!existingDecl) {
+              TODO(currentLocation,
+                   "User-defined reductions on allocatable or pointer "
+                   "variables: cannot find base reduction declaration");
+            }
+
+            std::string byrefName = getReductionName(
+                baseName, builder.getKindMap(), redType, isByRef);
+
+            mlir::Region &existingInitRegion =
+                existingDecl.getInitializerRegion();
+            auto genInitValueCB =
+                [&existingInitRegion](fir::FirOpBuilder &builder,
+                                      mlir::Location loc, mlir::Type elemTy,
+                                      mlir::Value) -> mlir::Value {
+              // unwrap box type to get the scalar element type
+              mlir::Type scalarTy = unwrapSeqOrBoxedType(elemTy);
+              // find a constant-producing op in the existing init region
+              mlir::Operation *constOp = nullptr;
+              existingInitRegion.walk([&](mlir::Operation *op) {
+                if (constOp)
+                  return;
+                if (auto arithConst =
+                        mlir::dyn_cast<mlir::arith::ConstantOp>(op)) {
+                  if (arithConst.getType() == scalarTy)
+                    constOp = op;
+                  return;
+                }
+                if (mlir::isa<fir::StringLitOp>(op)) {
+                  constOp = op;
+                  return;
+                }
+                if (mlir::isa<fir::AddrOfOp>(op)) {
+                  constOp = op;
+                  return;
+                }
+              });
+
+              if (constOp) {
+                mlir::IRMapping mapper;
+                mlir::Value cloned =
+                    builder.clone(*constOp, mapper)->getResult(0);
+                // load if the cloned op produces a reference
+                if (fir::isa_ref_type(cloned.getType()))
+                  cloned = fir::LoadOp::create(builder, loc, cloned);
+                if (cloned.getType() != scalarTy)
+                  cloned = builder.createConvert(loc, scalarTy, cloned);
+                return cloned;
+              }
+              // Fallback: zero-initialize for trivial types.
+              if (fir::isa_integer(scalarTy))
+                return builder.createIntegerConstant(loc, scalarTy, 0);
+              if (mlir::isa<mlir::FloatType>(scalarTy))
+                return builder.createRealConstant(loc, scalarTy,
+                                                  llvm::APFloat(0.0));
+              if (fir::isa_char(scalarTy))
+                return mlir::Value{};
+              TODO(loc, "User-defined reduction: unsupported init "
+                        "value type for allocatable wrapper");
+              return mlir::Value{};
+            };
+
+            // combiner: unbox, apply the existing combiner, rebox.
+            mlir::Region &existingCombinerRegion =
+                existingDecl.getReductionRegion();
+            bool existingCombinerIsByRef = fir::isa_ref_type(
+                existingCombinerRegion.front().getArgument(0).getType());
+            auto genCombinerCB = [&existingCombinerRegion,
+                                  existingCombinerIsByRef](
+                                     fir::FirOpBuilder &builder,
+                                     mlir::Location loc, mlir::Type type,
+                                     mlir::Value op1, mlir::Value op2, bool) {
+              // clone the existing combiner ops with remapped block args.
+              auto cloneCombiner =
+                  [&](mlir::Value lhs,
+                      mlir::Value rhs) -> std::optional<mlir::Value> {
+                mlir::IRMapping mapper;
+                mlir::Block &block = existingCombinerRegion.front();
+                mapper.map(block.getArgument(0), lhs);
+                mapper.map(block.getArgument(1), rhs);
+                mlir::Value result;
+                for (mlir::Operation &op : block) {
+                  if (auto yieldOp = mlir::dyn_cast<mlir::omp::YieldOp>(op)) {
+                    if (!existingCombinerIsByRef)
+                      result = mapper.lookup(yieldOp.getOperand(0));
+                    break;
+                  }
+                  builder.clone(op, mapper);
+                }
+                if (result)
+                  return result;
+                return std::nullopt;
+              };
+
+              auto boxTy =
+                  mlir::dyn_cast<fir::BaseBoxType>(fir::unwrapRefType(type));
+              if (!boxTy) {
+                TODO(loc, "User-defined reductions: unsupported byref type");
+                return;
+              }
+              // seqTy is non-null for array allocatables.
+              auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+                  fir::dyn_cast_ptrOrBoxEleTy(boxTy));
+
+              mlir::Value lhsBox = fir::LoadOp::create(builder, loc, op1);
+              mlir::Value rhsBox = fir::LoadOp::create(builder, loc, op2);
+
+              if (!seqTy) {
+                // Scalar allocatable
+                mlir::Value lhsAddr =
+                    fir::BoxAddrOp::create(builder, loc, lhsBox);
+                mlir::Value rhsAddr =
+                    fir::BoxAddrOp::create(builder, loc, rhsBox);
+                if (existingCombinerIsByRef) {
+                  mlir::Type expectedArgTy =
+                      existingCombinerRegion.front().getArgument(0).getType();
+                  lhsAddr = builder.createConvert(loc, expectedArgTy, lhsAddr);
+                  rhsAddr = builder.createConvert(loc, expectedArgTy, rhsAddr);
+                  cloneCombiner(lhsAddr, rhsAddr);
+                } else {
+                  mlir::Value lhsVal =
+                      fir::LoadOp::create(builder, loc, lhsAddr);
+                  mlir::Value rhsVal =
+                      fir::LoadOp::create(builder, loc, rhsAddr);
+                  if (auto result = cloneCombiner(lhsVal, rhsVal))
+                    fir::StoreOp::create(builder, loc, *result, lhsAddr);
+                }
+              } else {
+                // array allocatable: iterate elements
+                fir::ShapeShiftOp shapeShift =
+                    getShapeShift(builder, loc, lhsBox,
+                                  /*cannotHaveNonDefaultLowerBounds=*/false,
+                                  /*useDefaultLowerBounds=*/true);
+                hlfir::LoopNest nest =
+                    hlfir::genLoopNest(loc, builder, shapeShift.getExtents(),
+                                       /*isUnordered=*/true);
+                builder.setInsertionPointToStart(nest.body);
+                mlir::Type eleTy = seqTy.getEleTy();
+                mlir::Type refTy = fir::ReferenceType::get(
+                    eleTy, fir::isa_volatile_type(eleTy));
+                auto lhsEleAddr = fir::ArrayCoorOp::create(
+                    builder, loc, refTy, lhsBox, shapeShift,
+                    /*slice=*/mlir::Value{}, nest.oneBasedIndices,
+                    /*typeparms=*/mlir::ValueRange{});
+                auto rhsEleAddr = fir::ArrayCoorOp::create(
+                    builder, loc, refTy, rhsBox, shapeShift,
+                    /*slice=*/mlir::Value{}, nest.oneBasedIndices,
+                    /*typeparms=*/mlir::ValueRange{});
+                if (existingCombinerIsByRef) {
+                  cloneCombiner(lhsEleAddr, rhsEleAddr);
+                } else {
+                  mlir::Value lhsEle =
+                      fir::LoadOp::create(builder, loc, lhsEleAddr);
+                  mlir::Value rhsEle =
+                      fir::LoadOp::create(builder, loc, rhsEleAddr);
+                  if (auto result = cloneCombiner(lhsEle, rhsEle))
+                    fir::StoreOp::create(builder, loc, *result, lhsEleAddr);
+                }
+                builder.setInsertionPointAfter(nest.outerOp);
+              }
+              mlir::omp::YieldOp::create(builder, loc, op1);
+            };
+
+            OpType decl = createDeclareReductionHelper<OpType>(
+                converter, byrefName, redType, currentLocation, isByRef,
+                genCombinerCB, genInitValueCB);
+            reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+                builder.getContext(), decl.getSymName()));
+            ++idx;
+            continue;
+          }
           // Custom reductions we can just add to the symbols without
           // generating the declare reduction op.
           semantics::Symbol *sym = reductionIntrinsic->v.sym();
@@ -866,3 +1046,4 @@ int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
 } // namespace omp
 } // namespace lower
 } // namespace Fortran
+
diff --git a/flang/test/Lower/OpenMP/declare-reduction-allocatable.f90 b/flang/test/Lower/OpenMP/declare-reduction-allocatable.f90
new file mode 100644
index 0000000000000..b0b0b161c3c7c
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-reduction-allocatable.f90
@@ -0,0 +1,117 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=45 %s -o - | FileCheck %s
+
+subroutine test_udr_allocatable()
+  implicit none
+  integer :: i
+  integer, allocatable :: a, b(:), c(:,:)
+  ! Declare "foo" on plain integer; each allocatable below is reduced with it.
+  !$omp declare reduction (foo : integer : omp_out = omp_out + omp_in) &
+  !$omp & initializer (omp_priv = 0)
+  ! Allocate and zero-initialize every variable before the parallel loops.
+  allocate(a, b(4), c(3,2))
+  a = 0
+  b = 0
+  c = 0
+  ! Scalar allocatable: expected to lower to byref @foo_byref_box_heap_i32.
+  !$omp parallel do reduction(foo : a)
+  do i = 1, 10
+    a = a + i
+  end do
+  ! Rank-1 allocatable: expected to lower to byref @foo_byref_box_heap_Uxi32.
+  !$omp parallel do reduction(foo : b)
+  do i = 1, 10
+    b = b + i
+  end do
+  ! Rank-2 allocatable: expected to lower to byref @foo_byref_box_heap_UxUxi32.
+  !$omp parallel do reduction(foo : c)
+  do i = 1, 10
+    c = c + i
+  end do
+end subroutine
+
+! CHECK-LABEL: omp.declare_reduction @foo_byref_box_heap_UxUxi32 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
+! CHECK-SAME:  attributes {byref_element_type = !fir.array<?x?xi32>}
+! CHECK:       alloc {
+! CHECK:         fir.alloca !fir.box<!fir.heap<!fir.array<?x?xi32>>>
+! CHECK:         omp.yield
+! CHECK:       } init {
+! CHECK:         %[[C0_2D:.*]] = arith.constant 0 : i32
+! CHECK:         omp.yield
+! CHECK:       } combiner {
+! CHECK:       ^bb0(%[[ARG0_2D:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, %[[ARG1_2D:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>):
+! CHECK:         %[[LHS_BOX_2D:.*]] = fir.load %[[ARG0_2D]]
+! CHECK:         %[[RHS_BOX_2D:.*]] = fir.load %[[ARG1_2D]]
+! CHECK:         fir.shape_shift
+! CHECK:         fir.do_loop {{.*}} unordered {
+! CHECK:           fir.do_loop {{.*}} unordered {
+! CHECK:             fir.array_coor %[[LHS_BOX_2D]]
+! CHECK:             fir.array_coor %[[RHS_BOX_2D]]
+! CHECK:             arith.addi
+! CHECK:             fir.store
+! CHECK:           }
+! CHECK:         }
+! CHECK:         omp.yield(%[[ARG0_2D]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
+! CHECK:       } cleanup {
+! CHECK:         fir.freemem
+! CHECK:         omp.yield
+
+! CHECK-LABEL: omp.declare_reduction @foo_byref_box_heap_Uxi32 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK-SAME:  attributes {byref_element_type = !fir.array<?xi32>}
+! CHECK:       alloc {
+! CHECK:         fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
+! CHECK:         omp.yield
+! CHECK:       } init {
+! CHECK:         %[[C0_1D:.*]] = arith.constant 0 : i32
+! CHECK:         omp.yield
+! CHECK:       } combiner {
+! CHECK:       ^bb0(%[[ARG0_1D:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, %[[ARG1_1D:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>):
+! CHECK:         %[[LHS_BOX_1D:.*]] = fir.load %[[ARG0_1D]]
+! CHECK:         %[[RHS_BOX_1D:.*]] = fir.load %[[ARG1_1D]]
+! CHECK:         fir.shape_shift
+! CHECK:         fir.do_loop {{.*}} unordered {
+! CHECK:           fir.array_coor %[[LHS_BOX_1D]]
+! CHECK:           fir.array_coor %[[RHS_BOX_1D]]
+! CHECK:           arith.addi
+! CHECK:           fir.store
+! CHECK:         }
+! CHECK:         omp.yield(%[[ARG0_1D]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK:       } cleanup {
+! CHECK:         fir.freemem
+! CHECK:         omp.yield
+
+! CHECK-LABEL: omp.declare_reduction @foo_byref_box_heap_i32 : !fir.ref<!fir.box<!fir.heap<i32>>>
+! CHECK-SAME:  attributes {byref_element_type = i32}
+! CHECK:       alloc {
+! CHECK:         fir.alloca !fir.box<!fir.heap<i32>>
+! CHECK:         omp.yield
+! CHECK:       } init {
+! CHECK:         %[[C0_S:.*]] = arith.constant 0 : i32
+! CHECK:         omp.yield
+! CHECK:       } combiner {
+! CHECK:       ^bb0(%[[ARG0_S:.*]]: !fir.ref<!fir.box<!fir.heap<i32>>>, %[[ARG1_S:.*]]: !fir.ref<!fir.box<!fir.heap<i32>>>):
+! CHECK:         %[[LHS_BOX_S:.*]] = fir.load %[[ARG0_S]]
+! CHECK:         %[[RHS_BOX_S:.*]] = fir.load %[[ARG1_S]]
+! CHECK:         %[[LHS_ADDR_S:.*]] = fir.box_addr %[[LHS_BOX_S]]
+! CHECK:         %[[RHS_ADDR_S:.*]] = fir.box_addr %[[RHS_BOX_S]]
+! CHECK:         fir.load %[[LHS_ADDR_S]]
+! CHECK:         fir.load %[[RHS_ADDR_S]]
+! CHECK:         arith.addi
+! CHECK:         fir.store %{{.*}} to %[[LHS_ADDR_S]]
+! CHECK:         omp.yield(%[[ARG0_S]] : !fir.ref<!fir.box<!fir.heap<i32>>>)
+! CHECK:       } cleanup {
+! CHECK:         fir.freemem
+! CHECK:         omp.yield
+
+! CHECK-LABEL: omp.declare_reduction @foo : i32
+! CHECK:       init {
+! CHECK:         %[[C0_BASE:.*]] = arith.constant 0 : i32
+! CHECK:         omp.yield(%[[C0_BASE]] : i32)
+! CHECK:       } combiner {
+! CHECK:       ^bb0(%[[LHS_BASE:.*]]: i32, %[[RHS_BASE:.*]]: i32):
+! CHECK:         arith.addi
+! CHECK:         omp.yield
+
+! CHECK-LABEL: func.func @_QPtest_udr_allocatable
+! CHECK:         omp.wsloop {{.*}} reduction(byref @foo_byref_box_heap_i32 %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.heap<i32>>>)
+! CHECK:         omp.wsloop {{.*}} reduction(byref @foo_byref_box_heap_Uxi32 %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK:         omp.wsloop {{.*}} reduction(byref @foo_byref_box_heap_UxUxi32 %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)

>From b045a380311f0e8d9b0edc3c25a4f248234d1329 Mon Sep 17 00:00:00 2001
From: Ritanya-B-Bharadwaj <ritanya.b.bharadwaj at gmail.com>
Date: Mon, 16 Mar 2026 15:51:01 +0530
Subject: [PATCH 2/2] Update ReductionProcessor.cpp

---
 flang/lib/Lower/Support/ReductionProcessor.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index 078d8f24d5f7a..57efc982f908e 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -1046,4 +1046,3 @@ int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
 } // namespace omp
 } // namespace lower
 } // namespace Fortran
-



More information about the flang-commits mailing list