[flang-commits] [flang] [Flang][OpenMP] Fix crash and IR errors for user-defined reduction on allocatable variables (PR #186765)
via flang-commits
flang-commits at lists.llvm.org
Mon Mar 16 03:21:11 PDT 2026
https://github.com/Ritanya-B-Bharadwaj updated https://github.com/llvm/llvm-project/pull/186765
>From a00cc6584a801f2b969ee35d707f8ce436ca19a6 Mon Sep 17 00:00:00 2001
From: Ritanya B Bharadwaj <ritanya.b.bharadwaj at gmail.com>
Date: Mon, 16 Mar 2026 05:07:05 -0500
Subject: [PATCH 1/2] fixing flang issue #186743
---
.../lib/Lower/Support/ReductionProcessor.cpp | 181 ++++++++++++++++++
.../OpenMP/declare-reduction-allocatable.f90 | 117 +++++++++++
2 files changed, 298 insertions(+)
create mode 100644 flang/test/Lower/OpenMP/declare-reduction-allocatable.f90
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index e0cba4c512258..078d8f24d5f7a 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -23,6 +23,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/IRMapping.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>
@@ -803,6 +804,185 @@ bool ReductionProcessor::processReductionArguments(
&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
+ if (isByRef) {
+ // User-defined reduction on an allocatable/pointer variable:
+ // we need a new declare_reduction for the boxed type, reusing
+ // the init value and combiner from the existing one.
+ semantics::Symbol *sym = reductionIntrinsic->v.sym();
+ std::string baseName = sym->name().ToString();
+ mlir::ModuleOp module = builder.getModule();
+ auto existingDecl = module.lookupSymbol<OpType>(baseName);
+ if (!existingDecl) {
+ TODO(currentLocation,
+ "User-defined reductions on allocatable or pointer "
+ "variables: cannot find base reduction declaration");
+ }
+
+ std::string byrefName = getReductionName(
+ baseName, builder.getKindMap(), redType, isByRef);
+
+ mlir::Region &existingInitRegion =
+ existingDecl.getInitializerRegion();
+ auto genInitValueCB =
+ [&existingInitRegion](fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Type elemTy,
+ mlir::Value) -> mlir::Value {
+ // unwrap box type to get the scalar element type
+ mlir::Type scalarTy = unwrapSeqOrBoxedType(elemTy);
+ // find a constant-producing op in the existing init region
+ mlir::Operation *constOp = nullptr;
+ existingInitRegion.walk([&](mlir::Operation *op) {
+ if (constOp)
+ return;
+ if (auto arithConst =
+ mlir::dyn_cast<mlir::arith::ConstantOp>(op)) {
+ if (arithConst.getType() == scalarTy)
+ constOp = op;
+ return;
+ }
+ if (mlir::isa<fir::StringLitOp>(op)) {
+ constOp = op;
+ return;
+ }
+ if (mlir::isa<fir::AddrOfOp>(op)) {
+ constOp = op;
+ return;
+ }
+ });
+
+ if (constOp) {
+ mlir::IRMapping mapper;
+ mlir::Value cloned =
+ builder.clone(*constOp, mapper)->getResult(0);
+ // load if the cloned op produces a reference
+ if (fir::isa_ref_type(cloned.getType()))
+ cloned = fir::LoadOp::create(builder, loc, cloned);
+ if (cloned.getType() != scalarTy)
+ cloned = builder.createConvert(loc, scalarTy, cloned);
+ return cloned;
+ }
+ // Fallback: zero-initialize for trivial types.
+ if (fir::isa_integer(scalarTy))
+ return builder.createIntegerConstant(loc, scalarTy, 0);
+ if (mlir::isa<mlir::FloatType>(scalarTy))
+ return builder.createRealConstant(loc, scalarTy,
+ llvm::APFloat(0.0));
+ if (fir::isa_char(scalarTy))
+ return mlir::Value{};
+ TODO(loc, "User-defined reduction: unsupported init "
+ "value type for allocatable wrapper");
+ return mlir::Value{};
+ };
+
+ // combiner: unbox, apply the existing combiner, store the result in place.
+ mlir::Region &existingCombinerRegion =
+ existingDecl.getReductionRegion();
+ bool existingCombinerIsByRef = fir::isa_ref_type(
+ existingCombinerRegion.front().getArgument(0).getType());
+ auto genCombinerCB = [&existingCombinerRegion,
+ existingCombinerIsByRef](
+ fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Type type,
+ mlir::Value op1, mlir::Value op2, bool) {
+ // clone the existing combiner ops with remapped block args.
+ auto cloneCombiner =
+ [&](mlir::Value lhs,
+ mlir::Value rhs) -> std::optional<mlir::Value> {
+ mlir::IRMapping mapper;
+ mlir::Block &block = existingCombinerRegion.front();
+ mapper.map(block.getArgument(0), lhs);
+ mapper.map(block.getArgument(1), rhs);
+ mlir::Value result;
+ for (mlir::Operation &op : block) {
+ if (auto yieldOp = mlir::dyn_cast<mlir::omp::YieldOp>(op)) {
+ if (!existingCombinerIsByRef)
+ result = mapper.lookup(yieldOp.getOperand(0));
+ break;
+ }
+ builder.clone(op, mapper);
+ }
+ if (result)
+ return result;
+ return std::nullopt;
+ };
+
+ auto boxTy =
+ mlir::dyn_cast<fir::BaseBoxType>(fir::unwrapRefType(type));
+ if (!boxTy) {
+ TODO(loc, "User-defined reductions: unsupported byref type");
+ return;
+ }
+ // seqTy is non-null for array allocatables.
+ auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+ fir::dyn_cast_ptrOrBoxEleTy(boxTy));
+
+ mlir::Value lhsBox = fir::LoadOp::create(builder, loc, op1);
+ mlir::Value rhsBox = fir::LoadOp::create(builder, loc, op2);
+
+ if (!seqTy) {
+ // Scalar allocatable
+ mlir::Value lhsAddr =
+ fir::BoxAddrOp::create(builder, loc, lhsBox);
+ mlir::Value rhsAddr =
+ fir::BoxAddrOp::create(builder, loc, rhsBox);
+ if (existingCombinerIsByRef) {
+ mlir::Type expectedArgTy =
+ existingCombinerRegion.front().getArgument(0).getType();
+ lhsAddr = builder.createConvert(loc, expectedArgTy, lhsAddr);
+ rhsAddr = builder.createConvert(loc, expectedArgTy, rhsAddr);
+ cloneCombiner(lhsAddr, rhsAddr);
+ } else {
+ mlir::Value lhsVal =
+ fir::LoadOp::create(builder, loc, lhsAddr);
+ mlir::Value rhsVal =
+ fir::LoadOp::create(builder, loc, rhsAddr);
+ if (auto result = cloneCombiner(lhsVal, rhsVal))
+ fir::StoreOp::create(builder, loc, *result, lhsAddr);
+ }
+ } else {
+ // array allocatable: iterate elements
+ fir::ShapeShiftOp shapeShift =
+ getShapeShift(builder, loc, lhsBox,
+ /*cannotHaveNonDefaultLowerBounds=*/false,
+ /*useDefaultLowerBounds=*/true);
+ hlfir::LoopNest nest =
+ hlfir::genLoopNest(loc, builder, shapeShift.getExtents(),
+ /*isUnordered=*/true);
+ builder.setInsertionPointToStart(nest.body);
+ mlir::Type eleTy = seqTy.getEleTy();
+ mlir::Type refTy = fir::ReferenceType::get(
+ eleTy, fir::isa_volatile_type(eleTy));
+ auto lhsEleAddr = fir::ArrayCoorOp::create(
+ builder, loc, refTy, lhsBox, shapeShift,
+ /*slice=*/mlir::Value{}, nest.oneBasedIndices,
+ /*typeparms=*/mlir::ValueRange{});
+ auto rhsEleAddr = fir::ArrayCoorOp::create(
+ builder, loc, refTy, rhsBox, shapeShift,
+ /*slice=*/mlir::Value{}, nest.oneBasedIndices,
+ /*typeparms=*/mlir::ValueRange{});
+ if (existingCombinerIsByRef) {
+ cloneCombiner(lhsEleAddr, rhsEleAddr);
+ } else {
+ mlir::Value lhsEle =
+ fir::LoadOp::create(builder, loc, lhsEleAddr);
+ mlir::Value rhsEle =
+ fir::LoadOp::create(builder, loc, rhsEleAddr);
+ if (auto result = cloneCombiner(lhsEle, rhsEle))
+ fir::StoreOp::create(builder, loc, *result, lhsEleAddr);
+ }
+ builder.setInsertionPointAfter(nest.outerOp);
+ }
+ mlir::omp::YieldOp::create(builder, loc, op1);
+ };
+
+ OpType decl = createDeclareReductionHelper<OpType>(
+ converter, byrefName, redType, currentLocation, isByRef,
+ genCombinerCB, genInitValueCB);
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ builder.getContext(), decl.getSymName()));
+ ++idx;
+ continue;
+ }
// Custom reductions we can just add to the symbols without
// generating the declare reduction op.
semantics::Symbol *sym = reductionIntrinsic->v.sym();
@@ -866,3 +1046,4 @@ int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
} // namespace omp
} // namespace lower
} // namespace Fortran
+
diff --git a/flang/test/Lower/OpenMP/declare-reduction-allocatable.f90 b/flang/test/Lower/OpenMP/declare-reduction-allocatable.f90
new file mode 100644
index 0000000000000..b0b0b161c3c7c
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-reduction-allocatable.f90
@@ -0,0 +1,117 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=45 %s -o - | FileCheck %s
+
+subroutine test_udr_allocatable()
+ implicit none
+ integer :: i
+ integer, allocatable :: a, b(:), c(:,:)
+
+ !$omp declare reduction (foo : integer : omp_out = omp_out + omp_in) &
+ !$omp & initializer (omp_priv = 0)
+
+ allocate(a, b(4), c(3,2))
+ a = 0
+ b = 0
+ c = 0
+
+ !$omp parallel do reduction(foo : a)
+ do i = 1, 10
+ a = a + i
+ end do
+
+ !$omp parallel do reduction(foo : b)
+ do i = 1, 10
+ b = b + i
+ end do
+
+ !$omp parallel do reduction(foo : c)
+ do i = 1, 10
+ c = c + i
+ end do
+end subroutine
+
+! CHECK-LABEL: omp.declare_reduction @foo_byref_box_heap_UxUxi32 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
+! CHECK-SAME: attributes {byref_element_type = !fir.array<?x?xi32>}
+! CHECK: alloc {
+! CHECK: fir.alloca !fir.box<!fir.heap<!fir.array<?x?xi32>>>
+! CHECK: omp.yield
+! CHECK: } init {
+! CHECK: %[[C0_2D:.*]] = arith.constant 0 : i32
+! CHECK: omp.yield
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0_2D:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, %[[ARG1_2D:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>):
+! CHECK: %[[LHS_BOX_2D:.*]] = fir.load %[[ARG0_2D]]
+! CHECK: %[[RHS_BOX_2D:.*]] = fir.load %[[ARG1_2D]]
+! CHECK: fir.shape_shift
+! CHECK: fir.do_loop {{.*}} unordered {
+! CHECK: fir.do_loop {{.*}} unordered {
+! CHECK: fir.array_coor %[[LHS_BOX_2D]]
+! CHECK: fir.array_coor %[[RHS_BOX_2D]]
+! CHECK: arith.addi
+! CHECK: fir.store
+! CHECK: }
+! CHECK: }
+! CHECK: omp.yield(%[[ARG0_2D]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
+! CHECK: } cleanup {
+! CHECK: fir.freemem
+! CHECK: omp.yield
+
+! CHECK-LABEL: omp.declare_reduction @foo_byref_box_heap_Uxi32 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK-SAME: attributes {byref_element_type = !fir.array<?xi32>}
+! CHECK: alloc {
+! CHECK: fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
+! CHECK: omp.yield
+! CHECK: } init {
+! CHECK: %[[C0_1D:.*]] = arith.constant 0 : i32
+! CHECK: omp.yield
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0_1D:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, %[[ARG1_1D:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>):
+! CHECK: %[[LHS_BOX_1D:.*]] = fir.load %[[ARG0_1D]]
+! CHECK: %[[RHS_BOX_1D:.*]] = fir.load %[[ARG1_1D]]
+! CHECK: fir.shape_shift
+! CHECK: fir.do_loop {{.*}} unordered {
+! CHECK: fir.array_coor %[[LHS_BOX_1D]]
+! CHECK: fir.array_coor %[[RHS_BOX_1D]]
+! CHECK: arith.addi
+! CHECK: fir.store
+! CHECK: }
+! CHECK: omp.yield(%[[ARG0_1D]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: } cleanup {
+! CHECK: fir.freemem
+! CHECK: omp.yield
+
+! CHECK-LABEL: omp.declare_reduction @foo_byref_box_heap_i32 : !fir.ref<!fir.box<!fir.heap<i32>>>
+! CHECK-SAME: attributes {byref_element_type = i32}
+! CHECK: alloc {
+! CHECK: fir.alloca !fir.box<!fir.heap<i32>>
+! CHECK: omp.yield
+! CHECK: } init {
+! CHECK: %[[C0_S:.*]] = arith.constant 0 : i32
+! CHECK: omp.yield
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0_S:.*]]: !fir.ref<!fir.box<!fir.heap<i32>>>, %[[ARG1_S:.*]]: !fir.ref<!fir.box<!fir.heap<i32>>>):
+! CHECK: %[[LHS_BOX_S:.*]] = fir.load %[[ARG0_S]]
+! CHECK: %[[RHS_BOX_S:.*]] = fir.load %[[ARG1_S]]
+! CHECK: %[[LHS_ADDR_S:.*]] = fir.box_addr %[[LHS_BOX_S]]
+! CHECK: %[[RHS_ADDR_S:.*]] = fir.box_addr %[[RHS_BOX_S]]
+! CHECK: fir.load %[[LHS_ADDR_S]]
+! CHECK: fir.load %[[RHS_ADDR_S]]
+! CHECK: arith.addi
+! CHECK: fir.store %{{.*}} to %[[LHS_ADDR_S]]
+! CHECK: omp.yield(%[[ARG0_S]] : !fir.ref<!fir.box<!fir.heap<i32>>>)
+! CHECK: } cleanup {
+! CHECK: fir.freemem
+! CHECK: omp.yield
+
+! CHECK-LABEL: omp.declare_reduction @foo : i32
+! CHECK: init {
+! CHECK: %[[C0_BASE:.*]] = arith.constant 0 : i32
+! CHECK: omp.yield(%[[C0_BASE]] : i32)
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[LHS_BASE:.*]]: i32, %[[RHS_BASE:.*]]: i32):
+! CHECK: arith.addi
+! CHECK: omp.yield
+
+! CHECK-LABEL: func.func @_QPtest_udr_allocatable
+! CHECK: omp.wsloop {{.*}} reduction(byref @foo_byref_box_heap_i32 %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.heap<i32>>>)
+! CHECK: omp.wsloop {{.*}} reduction(byref @foo_byref_box_heap_Uxi32 %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: omp.wsloop {{.*}} reduction(byref @foo_byref_box_heap_UxUxi32 %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
>From b045a380311f0e8d9b0edc3c25a4f248234d1329 Mon Sep 17 00:00:00 2001
From: Ritanya-B-Bharadwaj <ritanya.b.bharadwaj at gmail.com>
Date: Mon, 16 Mar 2026 15:51:01 +0530
Subject: [PATCH 2/2] Update ReductionProcessor.cpp
---
flang/lib/Lower/Support/ReductionProcessor.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index 078d8f24d5f7a..57efc982f908e 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -1046,4 +1046,3 @@ int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
} // namespace omp
} // namespace lower
} // namespace Fortran
-
More information about the flang-commits mailing list