[flang-commits] [flang] [flang][OpenMP] `do concurrent`: support `reduce` on device (PR #156610)
Kareem Ergawy via flang-commits
flang-commits at lists.llvm.org
Mon Sep 22 22:28:55 PDT 2025
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/156610
>From f698b21b7affbbd664de3461bf92e8d0e783a4e0 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 2 Sep 2025 08:36:34 -0500
Subject: [PATCH] [flang][OpenMP] `do concurrent`: support `reduce` on device
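Extends `do concurrent` OpenMP device mapping with support for mapping
`reduce` specifiers to OpenMP `reduction` clauses. The change attaches two
`reduction` clauses to the mapped OpenMP construct: one on the `teams`
part of the construct and one on the `wsloop` part, where the `wsloop`
reduction operates on the block argument of the `teams` reduction. It also
removes a leftover debug dump of the enclosing module.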
---
.../OpenMP/DoConcurrentConversion.cpp | 117 ++++++++++--------
.../DoConcurrent/reduce_device.mlir | 53 ++++++++
2 files changed, 121 insertions(+), 49 deletions(-)
create mode 100644 flang/test/Transforms/DoConcurrent/reduce_device.mlir
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index fb99623128621..03ff16366a9d2 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -141,6 +141,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
for (mlir::Value local : loop.getLocalVars())
liveIns.push_back(local);
+
+ for (mlir::Value reduce : loop.getReduceVars())
+ liveIns.push_back(reduce);
}
/// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -319,7 +322,7 @@ class DoConcurrentConversion
targetOp =
genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
- genTeamsOp(doLoop.getLoc(), rewriter);
+ genTeamsOp(rewriter, loop, mapper);
}
mlir::omp::ParallelOp parallelOp =
@@ -492,46 +495,7 @@ class DoConcurrentConversion
if (!mapToDevice)
genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
- if (!loop.getReduceVars().empty()) {
- for (auto [op, byRef, sym, arg] : llvm::zip_equal(
- loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
- loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
- loop.getRegionReduceArgs())) {
- auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
- sym.getLeafReference());
-
- mlir::OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointAfter(firReducer);
- std::string ompReducerName = sym.getLeafReference().str() + ".omp";
-
- auto ompReducer =
- moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
- rewriter.getStringAttr(ompReducerName));
-
- if (!ompReducer) {
- ompReducer = mlir::omp::DeclareReductionOp::create(
- rewriter, firReducer.getLoc(), ompReducerName,
- firReducer.getTypeAttr().getValue());
-
- cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
- ompReducer.getAllocRegion());
- cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
- ompReducer.getInitializerRegion());
- cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
- ompReducer.getReductionRegion());
- cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
- ompReducer.getAtomicReductionRegion());
- cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
- ompReducer.getCleanupRegion());
- moduleSymbolTable.insert(ompReducer);
- }
-
- wsloopClauseOps.reductionVars.push_back(op);
- wsloopClauseOps.reductionByref.push_back(byRef);
- wsloopClauseOps.reductionSyms.push_back(
- mlir::SymbolRefAttr::get(ompReducer));
- }
- }
+ genReductions(rewriter, mapper, loop, wsloopClauseOps);
auto wsloopOp =
mlir::omp::WsloopOp::create(rewriter, loop.getLoc(), wsloopClauseOps);
@@ -553,8 +517,6 @@ class DoConcurrentConversion
rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
mlir::omp::YieldOp::create(rewriter, loop->getLoc());
- loop->getParentOfType<mlir::ModuleOp>().print(
- llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
return {loopNestOp, wsloopOp};
}
@@ -778,15 +740,26 @@ class DoConcurrentConversion
liveInName, shape);
}
- mlir::omp::TeamsOp
- genTeamsOp(mlir::Location loc,
- mlir::ConversionPatternRewriter &rewriter) const {
- auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(
- loc, /*clauses=*/mlir::omp::TeamsOperands{});
+  /// Creates an `omp.teams` op with one reduction per `reduce` specifier and
+  /// maps each reduce var to its teams block arg for the nested constructs.
+  mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter,
+                                fir::DoConcurrentLoopOp loop,
+                                mlir::IRMapping &mapper) const {
+    mlir::omp::TeamsOperands teamsOps;
+    genReductions(rewriter, mapper, loop, teamsOps);
+
+    mlir::Location loc = loop.getLoc();
+    auto teamsOp = mlir::omp::TeamsOp::create(rewriter, loc, teamsOps);
+    Fortran::common::openmp::EntryBlockArgs teamsArgs;
+    teamsArgs.reduction.vars = teamsOps.reductionVars;
+    Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs,
+                                           teamsOp.getRegion());
- rewriter.createBlock(&teamsOp.getRegion());
rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+    for (auto [loopVar, teamsArg] : llvm::zip_equal(
+             loop.getReduceVars(), teamsOp.getRegion().getArguments()))
+      mapper.map(loopVar, teamsArg);
return teamsOp;
}
@@ -861,6 +834,52 @@ class DoConcurrentConversion
}
}
+  /// Appends one OpenMP reduction operand per `reduce` specifier of `loop`,
+  /// creating (once) an `omp.declare_reduction` cloned from the referenced
+  /// `fir.declare_reduction`. On device, vars are remapped through `mapper`.
+  void genReductions(mlir::ConversionPatternRewriter &rewriter,
+                     mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
+                     mlir::omp::ReductionClauseOps &reductionClauseOps) const {
+    // Bail out early: the reduce attributes may be absent without reduce vars.
+    if (loop.getReduceVars().empty())
+      return;
+    for (auto [var, byRef, sym] : llvm::zip_equal(
+             loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
+             loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>())) {
+      std::string ompReducerName = sym.getLeafReference().str() + ".omp";
+      auto ompReducer =
+          moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
+              rewriter.getStringAttr(ompReducerName));
+      if (!ompReducer) {
+        auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
+            sym.getLeafReference());
+        assert(firReducer && "reduce symbol must name a fir.declare_reduction");
+        mlir::OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointAfter(firReducer);
+        ompReducer = mlir::omp::DeclareReductionOp::create(
+            rewriter, firReducer.getLoc(), ompReducerName,
+            firReducer.getTypeAttr().getValue());
+        cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
+                            ompReducer.getAllocRegion());
+        cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
+                            ompReducer.getInitializerRegion());
+        cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
+                            ompReducer.getReductionRegion());
+        cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
+                            ompReducer.getAtomicReductionRegion());
+        cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
+                            ompReducer.getCleanupRegion());
+        moduleSymbolTable.insert(ompReducer);
+      }
+
+      reductionClauseOps.reductionVars.push_back(
+          mapToDevice ? mapper.lookup(var) : var);
+      reductionClauseOps.reductionByref.push_back(byRef);
+      reductionClauseOps.reductionSyms.push_back(
+          mlir::SymbolRefAttr::get(ompReducer));
+    }
+  }
+
bool mapToDevice;
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
mlir::SymbolTable &moduleSymbolTable;
diff --git a/flang/test/Transforms/DoConcurrent/reduce_device.mlir b/flang/test/Transforms/DoConcurrent/reduce_device.mlir
new file mode 100644
index 0000000000000..3e46692a15dca
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/reduce_device.mlir
@@ -0,0 +1,53 @@
+// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
+// Checks that a device-mapped `reduce` becomes omp.teams + omp.wsloop reductions.
+fir.declare_reduction @add_reduction_f32 : f32 init {
+^bb0(%arg0: f32):
+  %cst = arith.constant 0.000000e+00 : f32
+  fir.yield(%cst : f32)
+} combiner {
+^bb0(%arg0: f32, %arg1: f32):
+  %0 = arith.addf %arg0, %arg1 fastmath<contract> : f32
+  fir.yield(%0 : f32)
+}
+
+func.func @_QPfoo() {
+  %0 = fir.dummy_scope : !fir.dscope
+  %3 = fir.alloca f32 {bindc_name = "s", uniq_name = "_QFfooEs"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  fir.do_concurrent {
+    %7 = fir.alloca i32 {bindc_name = "i"}
+    %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) reduce(@add_reduction_f32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<f32>) {
+      %9 = fir.convert %arg0 : (index) -> i32
+      fir.store %9 to %8#0 : !fir.ref<i32>
+      %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+      %11 = fir.load %10#0 : !fir.ref<f32>
+      %cst = arith.constant 1.000000e+00 : f32
+      %12 = arith.addf %11, %cst fastmath<contract> : f32
+      hlfir.assign %12 to %10#0 : f32, !fir.ref<f32>
+    }
+  }
+  return
+}
+
+// CHECK: omp.declare_reduction @[[OMP_RED:.*\.omp]] : f32
+
+// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFfooEs"}
+// CHECK: %[[S_MAP:.*]] = omp.map.info var_ptr(%[[S_DECL]]#1
+
+// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[S_MAP]] -> %[[S_TARGET_ARG:.*]] : {{.*}}) {
+// CHECK:   %[[S_DEV_DECL:.*]]:2 = hlfir.declare %[[S_TARGET_ARG]]
+// CHECK:   omp.teams reduction(@[[OMP_RED]] %[[S_DEV_DECL]]#0 -> %[[RED_TEAMS_ARG:.*]] : !fir.ref<f32>) {
+// CHECK:     omp.parallel {
+// CHECK:       omp.distribute {
+// CHECK:         omp.wsloop reduction(@[[OMP_RED]] %[[RED_TEAMS_ARG]] -> %[[RED_WS_ARG:.*]] : {{.*}}) {
+// CHECK:           %[[S_WS_DECL:.*]]:2 = hlfir.declare %[[RED_WS_ARG]] {uniq_name = "_QFfooEs"}
+// CHECK:           %[[S_VAL:.*]] = fir.load %[[S_WS_DECL]]#0
+// CHECK:           %[[RED_RES:.*]] = arith.addf %[[S_VAL]], %{{.*}} fastmath<contract> : f32
+// CHECK:           hlfir.assign %[[RED_RES]] to %[[S_WS_DECL]]#0
+// CHECK:         }
+// CHECK:       }
+// CHECK:     }
+// CHECK:   }
More information about the flang-commits
mailing list