[Mlir-commits] [flang] [mlir] [mlir][flang][openmp] Rework parallel reduction operations (PR #79308)
Kiran Chandramohan
llvmlistbot at llvm.org
Wed Feb 7 04:58:10 PST 2024
https://github.com/kiranchandramohan updated https://github.com/llvm/llvm-project/pull/79308
>From 10f98074cca27884a096ef679779a6d3a8b57205 Mon Sep 17 00:00:00 2001
From: David Truby <david at truby.dev>
Date: Wed, 24 Jan 2024 15:31:15 +0000
Subject: [PATCH 1/2] [mlir][flang][openmp] Rework parallel reduction
operations
This patch reworks the way that parallel reduction operations function to better
match the expected semantics from the OpenMP specification. Previously, specific
omp.reduction operations were used inside the region, meaning that the reduction
only applied when the correct operation was used, whereas the specification
states that any change to the variable inside the region should be taken into
account for the reduction.
The new semantics create a private reduction variable as a block argument which
should be used normally for all operations on that variable in the region; this
private variable is then combined with the others into the shared variable. This
way no special omp.reduction operations are needed inside the region.
This patch only makes the change for the `parallel` operation; the change for
the `wsloop` operation will be in a separate patch.
---
flang/lib/Lower/OpenMP.cpp | 92 +++++++++++++------
.../OpenMP/FIR/parallel-reduction-add.f90 | 26 ++++--
.../Lower/OpenMP/parallel-reduction-add.f90 | 26 ++++--
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 9 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 68 ++++++++++++++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 +++-
mlir/test/Dialect/OpenMP/ops.mlir | 39 ++++----
mlir/test/Target/LLVMIR/openmp-reduction.mlir | 12 ++-
8 files changed, 213 insertions(+), 75 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index be2117efbabc0a..fcf10b26c135b4 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -621,10 +621,12 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*mapSymbols = nullptr) const;
- bool processReduction(
- mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
+ bool
+ processReduction(mlir::Location currentLocation,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ *reductionSymbols = nullptr) const;
bool processSectionsReduction(mlir::Location currentLocation) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
@@ -1075,12 +1077,14 @@ class ReductionProcessor {
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
- static void addReductionDecl(
- mlir::Location currentLocation,
- Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
+ static void
+ addReductionDecl(mlir::Location currentLocation,
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::OmpReductionClause &reduction,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ *reductionSymbols = nullptr) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::omp::ReductionDeclareOp decl;
const auto &redOperator{
@@ -1110,6 +1114,8 @@ class ReductionProcessor {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
@@ -1142,6 +1148,8 @@ class ReductionProcessor {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
@@ -1935,13 +1943,16 @@ bool ClauseProcessor::processMap(
bool ClauseProcessor::processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
+ const {
return findRepeatableClause<ClauseTy::Reduction>(
[&](const ClauseTy::Reduction *reductionClause,
const Fortran::parser::CharBlock &) {
ReductionProcessor rp;
rp.addReductionDecl(currentLocation, converter, reductionClause->v,
- reductionVars, reductionDeclSymbols);
+ reductionVars, reductionDeclSymbols,
+ reductionSymbols);
});
}
@@ -2250,8 +2261,11 @@ static void createBodyOfOp(
Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
Fortran::lower::pft::Evaluation &eval, bool genNested,
const Fortran::parser::OmpClauseList *clauses = nullptr,
- const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
- bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs = {},
+ bool outerCombined = false, DataSharingProcessor *dsp = nullptr,
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs =
+ {},
+ const llvm::SmallVector<mlir::Type> &reductionTypes = {}) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
auto insertMarker = [](fir::FirOpBuilder &builder) {
@@ -2264,24 +2278,32 @@ static void createBodyOfOp(
// argument. Also update the symbol's address with the mlir argument value.
// e.g. For loops the argument is the induction variable. And all further
// uses of the induction variable should use this mlir value.
- if (args.size()) {
+ if (loopArgs.size()) {
std::size_t loopVarTypeSize = 0;
- for (const Fortran::semantics::Symbol *arg : args)
+ for (const Fortran::semantics::Symbol *arg : loopArgs)
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
- llvm::SmallVector<mlir::Location> locs(args.size(), loc);
+ llvm::SmallVector<mlir::Type> tiv(loopArgs.size(), loopVarType);
+ llvm::SmallVector<mlir::Location> locs(loopArgs.size(), loc);
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
// The argument is not currently in memory, so make a temporary for the
// argument, and store it there, then bind that location to the argument.
mlir::Operation *storeOp = nullptr;
- for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
+ for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
mlir::Value indexVal =
fir::getBase(op.getRegion().front().getArgument(argIndex));
storeOp =
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
}
firOpBuilder.setInsertionPointAfter(storeOp);
+ } else if (reductionArgs.size()) {
+ llvm::SmallVector<mlir::Location> locs(reductionArgs.size(), loc);
+ auto block =
+ firOpBuilder.createBlock(&op.getRegion(), {}, reductionTypes, locs);
+ for (auto [arg, prv] :
+ llvm::zip_equal(reductionArgs, block->getArguments())) {
+ converter.bindSymbol(*arg, prv);
+ }
} else {
firOpBuilder.createBlock(&op.getRegion());
}
@@ -2382,8 +2404,8 @@ static void createBodyOfOp(
assert(tempDsp.has_value());
tempDsp->processStep2(op, isLoop);
} else {
- if (isLoop && args.size() > 0)
- dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
+ if (isLoop && loopArgs.size() > 0)
+ dsp->setLoopIV(converter.getSymbolAddress(*loopArgs[0]));
dsp->processStep2(op, isLoop);
}
}
@@ -2468,7 +2490,8 @@ static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
currentLocation, std::forward<Args>(args)...);
createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
clauseList,
- /*args=*/{}, outerCombined);
+ /*loopArgs=*/{}, outerCombined, /*dsp=*/nullptr,
+ /*reductionArgs=*/{}, /*reductionTypes=*/{});
return op;
}
@@ -2505,6 +2528,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
ClauseProcessor cp(converter, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2514,11 +2538,11 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
cp.processDefault();
cp.processAllocate(allocatorOperands, allocateOperands);
if (!outerCombined)
- cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
+ cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
+ &reductionSymbols);
- return genOpWithBody<mlir::omp::ParallelOp>(
- converter, eval, genNested, currentLocation, outerCombined, &clauseList,
- /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
+ auto op = converter.getFirOpBuilder().create<mlir::omp::ParallelOp>(
+ currentLocation, mlir::TypeRange(), ifClauseOperand,
numThreadsClauseOperand, allocateOperands, allocatorOperands,
reductionVars,
reductionDeclSymbols.empty()
@@ -2526,6 +2550,17 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr);
+
+ llvm::SmallVector<mlir::Type> reductionTypes;
+ reductionTypes.reserve(reductionVars.size());
+ llvm::transform(reductionVars, std::back_inserter(reductionTypes),
+ [](mlir::Value v) { return v.getType(); });
+ createBodyOfOp<mlir::omp::ParallelOp>(op, converter, currentLocation, eval,
+ genNested, &clauseList, /*loopArgs=*/{},
+ outerCombined, /*dsp=*/nullptr,
+ reductionSymbols, reductionTypes);
+
+ return op;
}
static mlir::omp::SectionOp
@@ -3517,10 +3552,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
}
- if (singleDirective) {
- genOpenMPReduction(converter, beginClauseList);
+ if (singleDirective)
return;
- }
// Codegen for combined directives
bool combinedDirective = false;
@@ -3556,7 +3589,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
")");
genNestedEvaluations(converter, eval);
- genOpenMPReduction(converter, beginClauseList);
}
static void
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
index 6580aeb13ccd1e..4b223e822760a9 100644
--- a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
@@ -27,9 +27,11 @@
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
-!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
-!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
-!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
+!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
+!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
+!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -48,9 +50,11 @@ subroutine simple_int_add
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
-!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
-!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
+!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
+!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -72,11 +76,15 @@ subroutine simple_real_add
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
-!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
+!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
+!CHECK: fir.store %[[RES1]] to %[[PRV1]]
+!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
-!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
+!CHECK: fir.store %[[RES0]] to %[[PRV0]]
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
index 81a93aebbd2661..8f3ac3dc357af4 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-add.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
@@ -28,9 +28,12 @@
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
-!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -50,9 +53,12 @@ subroutine simple_int_add
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
-!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
+!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -76,11 +82,17 @@ subroutine simple_real_add
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
+!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
+!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref<f32>
+!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
+!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref<f32>
+!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
-!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
+!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 96c15e775a3024..77bea9db5276e1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -200,11 +200,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
unsigned getNumReductionVars() { return getReductionVars().size(); }
}];
let assemblyFormat = [{
- oilist( `reduction` `(`
- custom<ReductionVarList>(
- $reduction_vars, type($reduction_vars), $reductions
- ) `)`
- | `if` `(` $if_expr_var `:` type($if_expr_var) `)`
+ oilist(
+ `if` `(` $if_expr_var `:` type($if_expr_var) `)`
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
@@ -212,7 +209,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
- ) $region attr-dict
+ ) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 13cc16125a2733..26cdf7d3aa98e0 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -21,6 +21,7 @@
#include "mlir/Interfaces/FoldInterfaces.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
@@ -34,6 +35,7 @@
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
+#include "mlir/Support/LogicalResult.h"
using namespace mlir;
using namespace mlir::omp;
@@ -427,6 +429,71 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//
+ParseResult
+parseReductionClause(OpAsmParser &parser, Region &region,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+ SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
+ SmallVectorImpl<OpAsmParser::Argument> &privates) {
+ if (failed(parser.parseOptionalKeyword("reduction")))
+ return failure();
+
+ SmallVector<SymbolRefAttr> reductionVec;
+
+ if (failed(
+ parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
+ if (parser.parseAttribute(reductionVec.emplace_back()) ||
+ parser.parseOperand(operands.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseArgument(privates.emplace_back()) ||
+ parser.parseColonType(types.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+
+ for (auto [prv, type] : llvm::zip_equal(privates, types)) {
+ prv.type = type;
+ }
+ SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
+ reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
+ return success();
+}
+
+static void printReductionClause(OpAsmPrinter &p, Operation *op, Region &region,
+ ValueRange operands, TypeRange types,
+ ArrayAttr reductionSymbols) {
+ p << "reduction(";
+ llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
+ region.front().getArguments(), types),
+ p, [&p](auto t) {
+ auto [sym, op, arg, type] = t;
+ p << sym << " " << op << " -> " << arg << " : "
+ << type;
+ });
+ p << ") ";
+}
+
+static ParseResult
+parseParallelRegion(OpAsmParser &parser, Region &region,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+ SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
+
+ llvm::SmallVector<OpAsmParser::Argument> privates;
+ if (succeeded(parseReductionClause(parser, region, operands, types,
+ reductionSymbols, privates)))
+ return parser.parseRegion(region, privates);
+
+ return parser.parseRegion(region);
+}
+
+static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
+ ValueRange operands, TypeRange types,
+ ArrayAttr reductionSymbols) {
+ if (reductionSymbols)
+ printReductionClause(p, op, region, operands, types, reductionSymbols);
+ p.printRegion(region, /*printEntryBlockArgs=*/false);
+}
+
/// reduction-entry-list ::= reduction-entry
/// | reduction-entry-list `,` reduction-entry
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
@@ -1114,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region &region,
loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
for (auto &iv : ivs)
iv.type = loopVarType;
+
return parser.parseRegion(region, ivs);
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 23e101f1e45272..71b7937671b26e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// Allocate reduction vars
SmallVector<llvm::Value *> privateReductionVariables;
DenseMap<Value, llvm::Value *> reductionVariableMap;
- allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
- reductionDecls, privateReductionVariables,
- reductionVariableMap);
+ {
+ llvm::IRBuilderBase::InsertPointGuard guard(builder);
+ builder.restoreIP(allocaIP);
+ auto args = opInst.getRegion().getArguments();
+
+ for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) {
+ llvm::Value *var = builder.CreateAlloca(
+ moduleTranslation.convertType(reductionDecls[i].getType()));
+ moduleTranslation.mapValue(args[i], var);
+ privateReductionVariables.push_back(var);
+ reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var);
+ }
+ }
// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index ccf72ae31d439e..0451f4a0bfa234 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -640,11 +640,13 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
func.func @parallel_reduction() {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
- // CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr)
- omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
+ // CHECK: omp.parallel reduction(@add_f32 {{.+}} -> {{.+}} : !llvm.ptr)
+ omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
%1 = arith.constant 2.0 : f32
- // CHECK: omp.reduction %{{.+}}, %{{.+}}
- omp.reduction %1, %0 : f32, !llvm.ptr
+ %2 = llvm.load %prv : !llvm.ptr -> f32
+ // CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32
+ %3 = llvm.fadd %1, %2 : f32
+ llvm.store %3, %prv : f32, !llvm.ptr
omp.terminator
}
return
@@ -654,13 +656,14 @@ func.func @parallel_reduction() {
func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
- // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
- omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
+ // CHECK: omp.parallel reduction(@add_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) {
+ omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
// CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
%1 = arith.constant 2.0 : f32
- // CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr
- omp.reduction %1, %0 : f32, !llvm.ptr
+ %2 = llvm.load %prv : !llvm.ptr -> f32
+ // CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32
+ llvm.fadd %1, %2 : f32
// CHECK: omp.yield
omp.yield
}
@@ -799,11 +802,14 @@ func.func @wsloop_reduction2(%lb : index, %ub : index, %step : index) {
// CHECK-LABEL: func @parallel_reduction2
func.func @parallel_reduction2() {
%0 = memref.alloca() : memref<1xf32>
- // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
- omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) {
+ // CHECK: omp.parallel reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>)
+ omp.parallel reduction(@add2_f32 %0 -> %prv : memref<1xf32>) {
%1 = arith.constant 2.0 : f32
- // CHECK: omp.reduction
- omp.reduction %1, %0 : f32, memref<1xf32>
+ %2 = arith.constant 0 : index
+ %3 = memref.load %prv[%2] : memref<1xf32>
+ // CHECK: llvm.fadd
+ %4 = llvm.fadd %1, %3 : f32
+ memref.store %4, %prv[%2] : memref<1xf32>
omp.terminator
}
return
@@ -813,13 +819,14 @@ func.func @parallel_reduction2() {
func.func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
- // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr) {
- omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr) {
+ // CHECK: omp.parallel reduction(@add2_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) {
+ omp.parallel reduction(@add2_f32 %0 -> %prv : !llvm.ptr) {
// CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
%1 = arith.constant 2.0 : f32
- // CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr
- omp.reduction %1, %0 : f32, !llvm.ptr
+ %2 = llvm.load %prv : !llvm.ptr -> f32
+ // CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32
+ %3 = llvm.fadd %1, %2 : f32
// CHECK: omp.yield
omp.yield
}
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-reduction.mlir
index 93ab578df9e4e8..dae83c0cf92ed8 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction.mlir
@@ -441,9 +441,11 @@ atomic {
llvm.func @simple_reduction_parallel() {
%c1 = llvm.mlir.constant(1 : i32) : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
- omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
+ omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
%1 = llvm.mlir.constant(2.0 : f32) : f32
- omp.reduction %1, %0 : f32, !llvm.ptr
+ %2 = llvm.load %prv : !llvm.ptr -> f32
+ %3 = llvm.fadd %2, %1 : f32
+ llvm.store %3, %prv : f32, !llvm.ptr
omp.terminator
}
llvm.return
@@ -512,10 +514,12 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) {
%lb = llvm.mlir.constant(1 : i64) : i64
%step = llvm.mlir.constant(1 : i64) : i64
- omp.parallel reduction(@add_i32 -> %0 : !llvm.ptr) {
+ omp.parallel reduction(@add_i32 %0 -> %prv : !llvm.ptr) {
omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) {
%ival = llvm.trunc %iv : i64 to i32
- omp.reduction %ival, %0 : i32, !llvm.ptr
+ %lprv = llvm.load %prv : !llvm.ptr -> i32
+ %add = llvm.add %lprv, %ival : i32
+ llvm.store %add, %prv : i32, !llvm.ptr
omp.yield
}
omp.terminator
>From 1ed4bc39de92defb5d43df3c7ef69e0b53e80d30 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Wed, 7 Feb 2024 12:52:43 +0000
Subject: [PATCH 2/2] [Flang][OpenMP] Add new test to demonstrate reduction
---
.../test/Lower/OpenMP/parallel-reduction.f90 | 38 +++++++++++++++++++
1 file changed, 38 insertions(+)
create mode 100644 flang/test/Lower/OpenMP/parallel-reduction.f90
diff --git a/flang/test/Lower/OpenMP/parallel-reduction.f90 b/flang/test/Lower/OpenMP/parallel-reduction.f90
new file mode 100644
index 00000000000000..a07d118b0ba19a
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction.f90
@@ -0,0 +1,38 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK: omp.reduction.declare @[[REDUCTION_DECLARE:[_a-z0-9]+]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK: %[[I0:[_a-z0-9]+]] = arith.constant 0 : i32
+!CHECK: omp.yield(%[[I0]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[C0:[_a-z0-9]+]]: i32, %[[C1:[_a-z0-9]+]]: i32):
+!CHECK: %[[CR:[_a-z0-9]+]] = arith.addi %[[C0]], %[[C1]] : i32
+!CHECK: omp.yield(%[[CR]] : i32)
+!CHECK: }
+!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "mn"} {
+!CHECK: %[[RED_ACCUM_REF:[_a-z0-9]+]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
+!CHECK: %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32
+!CHECK: hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK: omp.parallel reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
+!CHECK: %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32
+!CHECK: hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: %[[RED_ACCUM_VAL:[_a-z0-9]+]] = fir.load %[[RED_ACCUM_DECL]]#0 : !fir.ref<i32>
+!CHECK: {{.*}} = fir.call @_FortranAioOutputInteger32(%{{.*}}, %[[RED_ACCUM_VAL]]) fastmath<contract> : (!fir.ref<i8>, i32) -> i1
+!CHECK: return
+!CHECK: }
+
+program mn
+ integer :: i
+ i = 0
+
+ !$omp parallel reduction(+:i)
+ i = 1
+ !$omp end parallel
+
+ print *, i
+end program
More information about the Mlir-commits
mailing list