[flang-commits] [flang] [Flang][OpenMP]Support for lowering task_reduction and in_reduction to MLIR (PR #111155)
Kaviya Rajendiran via flang-commits
flang-commits at lists.llvm.org
Fri Dec 13 03:58:51 PST 2024
https://github.com/kaviya2510 updated https://github.com/llvm/llvm-project/pull/111155
>From 60cbcc29d9d0628db19e498377759b6affb2b2b5 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Fri, 6 Dec 2024 18:40:03 +0530
Subject: [PATCH 1/2] [Flang][OpenMP]Support for lowering task_reduction and
in_reduction to MLIR
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 55 +++++++++++++-
flang/lib/Lower/OpenMP/ClauseProcessor.h | 7 ++
.../lib/Lower/OpenMP/DataSharingProcessor.cpp | 7 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 71 +++++++++++++------
flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 41 +++++++++--
flang/lib/Lower/OpenMP/ReductionProcessor.h | 4 +-
.../Lower/OpenMP/Todo/task-inreduction.f90 | 15 ----
.../OpenMP/Todo/taskgroup-task-reduction.f90 | 10 ---
flang/test/Lower/OpenMP/task-inreduction.f90 | 35 +++++++++
.../OpenMP/taskgroup-task-array-reduction.f90 | 49 +++++++++++++
.../OpenMP/taskgroup-task_reduction01.f90 | 34 +++++++++
.../OpenMP/taskgroup-task_reduction02.f90 | 36 ++++++++++
12 files changed, 305 insertions(+), 59 deletions(-)
delete mode 100644 flang/test/Lower/OpenMP/Todo/task-inreduction.f90
delete mode 100644 flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90
create mode 100644 flang/test/Lower/OpenMP/task-inreduction.f90
create mode 100644 flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90
create mode 100644 flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90
create mode 100644 flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 48c559a78b9bc4..1f94458ff0b976 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -916,6 +916,30 @@ bool ClauseProcessor::processIsDevicePtr(
});
}
+bool ClauseProcessor::processInReduction(
+ mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
+ return findRepeatableClause<omp::clause::InReduction>(
+ [&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
+ llvm::SmallVector<mlir::Value> inReductionVars;
+ llvm::SmallVector<bool> inReduceVarByRef;
+ llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
+ llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
+ ReductionProcessor rp;
+ rp.addDeclareReduction<omp::clause::InReduction>(
+ currentLocation, converter, clause, inReductionVars,
+ inReduceVarByRef, inReductionDeclSymbols, inReductionSyms);
+
+ // Copy local lists into the output.
+ llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
+ llvm::copy(inReduceVarByRef,
+ std::back_inserter(result.inReductionByref));
+ llvm::copy(inReductionDeclSymbols,
+ std::back_inserter(result.inReductionSyms));
+ llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
+ });
+}
+
bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::Link>(
@@ -1126,9 +1150,10 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
- rp.addDeclareReduction(currentLocation, converter, clause,
- reductionVars, reduceVarByRef,
- reductionDeclSymbols, reductionSyms);
+
+ rp.addDeclareReduction<omp::clause::Reduction>(
+ currentLocation, converter, clause, reductionVars, reduceVarByRef,
+ reductionDeclSymbols, reductionSyms);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
@@ -1139,6 +1164,30 @@ bool ClauseProcessor::processReduction(
});
}
+bool ClauseProcessor::processTaskReduction(
+ mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
+ return findRepeatableClause<omp::clause::TaskReduction>(
+ [&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
+ llvm::SmallVector<mlir::Value> taskReductionVars;
+ llvm::SmallVector<bool> TaskReduceVarByRef;
+ llvm::SmallVector<mlir::Attribute> TaskReductionDeclSymbols;
+ llvm::SmallVector<const semantics::Symbol *> TaskReductionSyms;
+ ReductionProcessor rp;
+ rp.addDeclareReduction<omp::clause::TaskReduction>(
+ currentLocation, converter, clause, taskReductionVars,
+ TaskReduceVarByRef, TaskReductionDeclSymbols, TaskReductionSyms);
+ // Copy local lists into the output.
+ llvm::copy(taskReductionVars,
+ std::back_inserter(result.taskReductionVars));
+ llvm::copy(TaskReduceVarByRef,
+ std::back_inserter(result.taskReductionByref));
+ llvm::copy(TaskReductionDeclSymbols,
+ std::back_inserter(result.taskReductionSyms));
+ llvm::copy(TaskReductionSyms, std::back_inserter(outReductionSyms));
+ });
+}
+
bool ClauseProcessor::processTo(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::To>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index e0fe917c50e8f8..e042d3a1efdc82 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -105,6 +105,9 @@ class ClauseProcessor {
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+ bool processInReduction(
+ mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
@@ -123,6 +126,10 @@ class ClauseProcessor {
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
+ bool processTaskReduction(mlir::Location currentLocation,
+ mlir::omp::TaskReductionClauseOps &result,
+ llvm::SmallVectorImpl<const semantics::Symbol *>
+ &TaskReductionSyms) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 99835c515463b9..d4377498ccad04 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -344,8 +344,13 @@ void DataSharingProcessor::collectSymbols(
// Collect all symbols referenced in the evaluation being processed,
// that matches 'flag'.
llvm::SetVector<const semantics::Symbol *> allSymbols;
+ bool collectSymbols = true;
+ for (const omp::Clause &clause : clauses) {
+ if (clause.id == llvm::omp::Clause::OMPC_in_reduction)
+ collectSymbols = false;
+ }
converter.collectSymbolSet(eval, allSymbols, flag,
- /*collectSymbols=*/true,
+ /*collectSymbols=*/collectSymbols,
/*collectHostAssociatedSymbols=*/true);
llvm::SetVector<const semantics::Symbol *> symbolsInNestedRegions;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c167d347b43159..f657a2ef0a26d1 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1249,34 +1249,35 @@ static void genTargetEnterExitUpdateDataClauses(
cp.processNowait(clauseOps);
}
-static void genTaskClauses(lower::AbstractConverter &converter,
- semantics::SemanticsContext &semaCtx,
- lower::StatementContext &stmtCtx,
- const List<Clause> &clauses, mlir::Location loc,
- mlir::omp::TaskOperands &clauseOps) {
+static void genTaskClauses(
+ lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processDepend(clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
+ cp.processInReduction(loc, clauseOps, InReductionSyms);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
cp.processUntied(clauseOps);
cp.processDetach(clauseOps);
// TODO Support delayed privatization.
- cp.processTODO<clause::Affinity, clause::InReduction>(
+ cp.processTODO<clause::Affinity>(
loc, llvm::omp::Directive::OMPD_task);
}
-static void genTaskgroupClauses(lower::AbstractConverter &converter,
- semantics::SemanticsContext &semaCtx,
- const List<Clause> &clauses, mlir::Location loc,
- mlir::omp::TaskgroupOperands &clauseOps) {
+static void genTaskgroupClauses(
+ lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::TaskgroupOperands &clauseOps,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
- cp.processTODO<clause::TaskReduction>(loc,
- llvm::omp::Directive::OMPD_taskgroup);
+ cp.processTaskReduction(loc, clauseOps, taskReductionSyms);
}
static void genTaskwaitClauses(lower::AbstractConverter &converter,
@@ -1887,7 +1888,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;
mlir::omp::TaskOperands clauseOps;
- genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
+ llvm::SmallVector<const semantics::Symbol *> InReductionSyms;
+ genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
+ InReductionSyms);
if (!enableDelayedPrivatization)
return genOpWithBody<mlir::omp::TaskOp>(
@@ -1904,22 +1907,35 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
EntryBlockArgs taskArgs;
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
taskArgs.priv.vars = clauseOps.privateVars;
+ taskArgs.inReduction.syms = InReductionSyms;
+ taskArgs.inReduction.vars = clauseOps.inReductionVars;
auto genRegionEntryCB = [&](mlir::Operation *op) {
genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
bindEntryBlockArgs(converter,
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
taskArgs);
- return llvm::to_vector(taskArgs.priv.syms);
+ return llvm::to_vector(taskArgs.getSyms());
};
- return genOpWithBody<mlir::omp::TaskOp>(
+ OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_task)
.setClauses(&item->clauses)
.setDataSharingProcessor(&dsp)
- .setGenRegionEntryCb(genRegionEntryCB),
- queue, item, clauseOps);
+ .setGenRegionEntryCb(genRegionEntryCB);
+
+ auto taskOp =
+ genOpWithBody<mlir::omp::TaskOp>(genInfo, queue, item, clauseOps);
+
+ llvm::SmallVector<mlir::Type> inReductionTypes;
+ for (const auto &inreductionVar : clauseOps.inReductionVars)
+ inReductionTypes.push_back(inreductionVar.getType());
+
+ // Add reduction variables as entry block arguments to the task region
+ llvm::SmallVector<mlir::Location> blockArgLocs(InReductionSyms.size(), loc);
+ taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
+ return taskOp;
}
static mlir::omp::TaskgroupOp
@@ -1929,13 +1945,26 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
mlir::omp::TaskgroupOperands clauseOps;
- genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
+ llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
+ genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
+ taskReductionSyms);
- return genOpWithBody<mlir::omp::TaskgroupOp>(
+ OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_taskgroup)
- .setClauses(&item->clauses),
- queue, item, clauseOps);
+ .setClauses(&item->clauses);
+
+ auto taskgroupOp =
+ genOpWithBody<mlir::omp::TaskgroupOp>(genInfo, queue, item, clauseOps);
+
+ llvm::SmallVector<mlir::Type> taskReductionTypes;
+ for (const auto &taskreductionVar : clauseOps.taskReductionVars)
+ taskReductionTypes.push_back(taskreductionVar.getType());
+
+ // Add reduction variables as entry block arguments to the taskgroup region
+ llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
+ taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
+ return taskgroupOp;
}
static mlir::omp::TaskwaitOp
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 736de2ee511bef..4bdfda701a9c88 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -24,6 +24,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
+#include <type_traits>
static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
@@ -34,6 +35,32 @@ namespace Fortran {
namespace lower {
namespace omp {
+// Explicit template instantiations.
+template void ReductionProcessor::addDeclareReduction<omp::clause::Reduction>(
+ mlir::Location currentLocation, lower::AbstractConverter &converter,
+ const omp::clause::Reduction &reduction,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<bool> &reduceVarByRef,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
+
+template void
+ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
+ mlir::Location currentLocation, lower::AbstractConverter &converter,
+ const omp::clause::TaskReduction &reduction,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<bool> &reduceVarByRef,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
+
+template void ReductionProcessor::addDeclareReduction<omp::clause::InReduction>(
+ mlir::Location currentLocation, lower::AbstractConverter &converter,
+ const omp::clause::InReduction &reduction,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<bool> &reduceVarByRef,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
+
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
@@ -716,22 +743,22 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}
+template <class T>
void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
- const omp::clause::Reduction &reduction,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
- reduction.t))
- TODO(currentLocation, "Reduction modifiers are not supported");
+ if constexpr (std::is_same<T, omp::clause::Reduction>::value) {
+ if (std::get<std::optional<typename T::ReductionModifier>>(reduction.t))
+ TODO(currentLocation, "Reduction modifiers are not supported");
+ }
mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
- std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
+ std::get<typename T::ReductionIdentifiers>(reduction.t)};
assert(redOperatorList.size() == 1 && "Expecting single operator");
const auto &redOperator = redOperatorList.front();
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 5f4d742b62cb10..91b54da314243d 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -120,10 +120,10 @@ class ReductionProcessor {
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
+ template <class T>
static void addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
- const omp::clause::Reduction &reduction,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
diff --git a/flang/test/Lower/OpenMP/Todo/task-inreduction.f90 b/flang/test/Lower/OpenMP/Todo/task-inreduction.f90
deleted file mode 100644
index aeed680a6dba7c..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/task-inreduction.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-!===============================================================================
-! `mergeable` clause
-!===============================================================================
-
-! CHECK: not yet implemented: Unhandled clause IN_REDUCTION in TASK construct
-subroutine omp_task_in_reduction()
- integer i
- i = 0
- !$omp task in_reduction(+:i)
- i = i + 1
- !$omp end task
-end subroutine omp_task_in_reduction
diff --git a/flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90 b/flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90
deleted file mode 100644
index 1cb471d784d766..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90
+++ /dev/null
@@ -1,10 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s -fopenmp-version=50 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s -fopenmp-version=50 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Unhandled clause TASK_REDUCTION in TASKGROUP construct
-subroutine omp_taskgroup_task_reduction
- integer :: res
- !$omp taskgroup task_reduction(+:res)
- res = res + 1
- !$omp end taskgroup
-end subroutine omp_taskgroup_task_reduction
diff --git a/flang/test/Lower/OpenMP/task-inreduction.f90 b/flang/test/Lower/OpenMP/task-inreduction.f90
new file mode 100644
index 00000000000000..ded4710d5c13d6
--- /dev/null
+++ b/flang/test/Lower/OpenMP/task-inreduction.f90
@@ -0,0 +1,35 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK: omp.yield(%[[C0_1]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+!CHECK: omp.yield(%[[RES]] : i32)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPomp_task_in_reduction() {
+! [...]
+!CHECK: omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1:.*]]#0 -> %[[ARG0]] : !fir.ref<i32>) {
+!CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[ARG0]]
+!CHECK-SAME: {uniq_name = "_QFomp_task_in_reductionEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<i32>
+!CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
+!CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
+!CHECK: hlfir.assign %[[VAL_7]] to %[[VAL_4]]#0 : i32, !fir.ref<i32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+!CHECK: }
+
+subroutine omp_task_in_reduction()
+ integer i
+ i = 0
+ !$omp task in_reduction(+:i)
+ i = i + 1
+ !$omp end task
+end subroutine omp_task_in_reduction
\ No newline at end of file
diff --git a/flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90 b/flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90
new file mode 100644
index 00000000000000..7e6d7f09fbc679
--- /dev/null
+++ b/flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90
@@ -0,0 +1,49 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_Uxf32 : !fir.ref<!fir.box<!fir.array<?xf32>>> alloc {
+! [...]
+! CHECK: omp.yield
+! CHECK-LABEL: } init {
+! [...]
+! CHECK: omp.yield
+! CHECK-LABEL: } combiner {
+! [...]
+! CHECK: omp.yield
+! CHECK-LABEL: } cleanup {
+! [...]
+! CHECK: omp.yield
+! CHECK: }
+
+! CHECK-LABEL: func.func @_QPtaskreduction
+! CHECK-SAME: (%[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
+! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
+! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]]
+! CHECK-SAME: {uniq_name = "_QFtaskreductionEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
+! CHECK: omp.parallel {
+! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
+! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
+! CHECK: omp.taskgroup task_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>) {
+! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
+! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
+! CHECK: omp.task in_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
+! [...]
+! CHECK: omp.terminator
+! CHECK: }
+! CHECK: omp.terminator
+! CHECK: }
+! CHECK: omp.terminator
+! CHECK: }
+! CHECK: return
+! CHECK: }
+
+subroutine taskReduction(x)
+ real, dimension(:) :: x
+ !$omp parallel
+ !$omp taskgroup task_reduction(+:x)
+ !$omp task in_reduction(+:x)
+ x = x + 1
+ !$omp end task
+ !$omp end taskgroup
+ !$omp end parallel
+end subroutine
\ No newline at end of file
diff --git a/flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90 b/flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90
new file mode 100644
index 00000000000000..bc32cee93d47f1
--- /dev/null
+++ b/flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90
@@ -0,0 +1,34 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK: omp.yield(%[[C0_1]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+!CHECK: omp.yield(%[[RES]] : i32)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPomp_taskgroup_task_reduction() {
+!CHECK: %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "res", uniq_name = "_QFomp_taskgroup_task_reductionEres"}
+!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFomp_taskgroup_task_reductionEres"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: omp.taskgroup task_reduction(@[[RED_I32_NAME]] %[[VAL_1]]#0 -> %[[VAL_2:.*]] : !fir.ref<i32>) {
+!CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<i32>
+!CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
+!CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : i32
+!CHECK: hlfir.assign %[[VAL_5]] to %[[VAL_1]]#0 : i32, !fir.ref<i32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+!CHECK: }
+
+
+subroutine omp_taskgroup_task_reduction()
+ integer :: res
+ !$omp taskgroup task_reduction(+:res)
+ res = res + 1
+ !$omp end taskgroup
+end subroutine omp_taskgroup_task_reduction
\ No newline at end of file
diff --git a/flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90 b/flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90
new file mode 100644
index 00000000000000..6a5bc568efb8e4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90
@@ -0,0 +1,36 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK: omp.yield(%[[C0_1]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+!CHECK: omp.yield(%[[RES]] : i32)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPin_reduction() {
+! [...]
+!CHECK: omp.taskgroup task_reduction(@[[RED_I32_NAME]] %[[VAL_1:.*]]#0 -> %[[VAL_3:.*]] : !fir.ref<i32>) {
+!CHECK: omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1]]#0 -> %[[VAL_4:.*]] : !fir.ref<i32>) {
+!CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = "_QFin_reductionEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! [...]
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+!CHECK: }
+
+subroutine in_reduction
+ integer :: x
+ x = 0
+ !$omp taskgroup task_reduction(+:x)
+ !$omp task in_reduction(+:x)
+ x = x + 1
+ !$omp end task
+ !$omp end taskgroup
+end subroutine
\ No newline at end of file
>From 72d223028ff2e6fda8fe90f62d94b6007e1febab Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Fri, 13 Dec 2024 17:21:53 +0530
Subject: [PATCH 2/2] [Flang][OpenMP] Addressed review comments
---
flang/lib/Lower/OpenMP/ClauseProcessor.h | 9 ++---
.../lib/Lower/OpenMP/DataSharingProcessor.cpp | 12 +++---
flang/lib/Lower/OpenMP/OpenMP.cpp | 38 ++++++++-----------
flang/test/Lower/OpenMP/task-inreduction.f90 | 2 +-
.../OpenMP/taskgroup-task-array-reduction.f90 | 2 +-
.../OpenMP/taskgroup-task_reduction01.f90 | 2 +-
.../OpenMP/taskgroup-task_reduction02.f90 | 4 +-
7 files changed, 32 insertions(+), 37 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index e042d3a1efdc82..764964fc706e47 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -107,7 +107,7 @@ class ClauseProcessor {
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
@@ -126,10 +126,9 @@ class ClauseProcessor {
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
- bool processTaskReduction(mlir::Location currentLocation,
- mlir::omp::TaskReductionClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *>
- &TaskReductionSyms) const;
+ bool processTaskReduction(
+ mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index d4377498ccad04..b4422cdd725466 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -344,11 +344,13 @@ void DataSharingProcessor::collectSymbols(
// Collect all symbols referenced in the evaluation being processed,
// that matches 'flag'.
llvm::SetVector<const semantics::Symbol *> allSymbols;
- bool collectSymbols = true;
- for (const omp::Clause &clause : clauses) {
- if (clause.id == llvm::omp::Clause::OMPC_in_reduction)
- collectSymbols = false;
- }
+
+ auto itr = llvm::find_if(clauses, [](const omp::Clause &clause) {
+ return clause.id == llvm::omp::Clause::OMPC_in_reduction;
+ });
+
+ bool collectSymbols = (itr == clauses.end());
+
converter.collectSymbolSet(eval, allSymbols, flag,
/*collectSymbols=*/collectSymbols,
/*collectHostAssociatedSymbols=*/true);
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f657a2ef0a26d1..76270aee904fff 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1253,13 +1253,13 @@ static void genTaskClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
- llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processDepend(clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
- cp.processInReduction(loc, clauseOps, InReductionSyms);
+ cp.processInReduction(loc, clauseOps, inReductionSyms);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
cp.processUntied(clauseOps);
@@ -1888,9 +1888,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;
mlir::omp::TaskOperands clauseOps;
- llvm::SmallVector<const semantics::Symbol *> InReductionSyms;
+ llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
- InReductionSyms);
+ inReductionSyms);
if (!enableDelayedPrivatization)
return genOpWithBody<mlir::omp::TaskOp>(
@@ -1907,7 +1907,7 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
EntryBlockArgs taskArgs;
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
taskArgs.priv.vars = clauseOps.privateVars;
- taskArgs.inReduction.syms = InReductionSyms;
+ taskArgs.inReduction.syms = inReductionSyms;
taskArgs.inReduction.vars = clauseOps.inReductionVars;
auto genRegionEntryCB = [&](mlir::Operation *op) {
@@ -1927,14 +1927,6 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
auto taskOp =
genOpWithBody<mlir::omp::TaskOp>(genInfo, queue, item, clauseOps);
-
- llvm::SmallVector<mlir::Type> inReductionTypes;
- for (const auto &inreductionVar : clauseOps.inReductionVars)
- inReductionTypes.push_back(inreductionVar.getType());
-
- // Add reduction variables as entry block arguments to the task region
- llvm::SmallVector<mlir::Location> blockArgLocs(InReductionSyms.size(), loc);
- taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
return taskOp;
}
@@ -1949,21 +1941,23 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
taskReductionSyms);
+ EntryBlockArgs taskgroupArgs;
+ taskgroupArgs.taskReduction.syms = taskReductionSyms;
+ taskgroupArgs.taskReduction.vars = clauseOps.taskReductionVars;
+
+ auto genRegionEntryCB = [&](mlir::Operation *op) {
+ genEntryBlock(converter.getFirOpBuilder(), taskgroupArgs, op->getRegion(0));
+ return llvm::to_vector(taskgroupArgs.getSyms());
+ };
+
OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_taskgroup)
- .setClauses(&item->clauses);
+ .setClauses(&item->clauses)
+ .setGenRegionEntryCb(genRegionEntryCB);
auto taskgroupOp =
genOpWithBody<mlir::omp::TaskgroupOp>(genInfo, queue, item, clauseOps);
-
- llvm::SmallVector<mlir::Type> taskReductionTypes;
- for (const auto &taskreductionVar : clauseOps.taskReductionVars)
- taskReductionTypes.push_back(taskreductionVar.getType());
-
- // Add reduction variables as entry block arguments to the taskgroup region
- llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
- taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
return taskgroupOp;
}
diff --git a/flang/test/Lower/OpenMP/task-inreduction.f90 b/flang/test/Lower/OpenMP/task-inreduction.f90
index ded4710d5c13d6..41657d320f7d25 100644
--- a/flang/test/Lower/OpenMP/task-inreduction.f90
+++ b/flang/test/Lower/OpenMP/task-inreduction.f90
@@ -32,4 +32,4 @@ subroutine omp_task_in_reduction()
!$omp task in_reduction(+:i)
i = i + 1
!$omp end task
-end subroutine omp_task_in_reduction
\ No newline at end of file
+end subroutine omp_task_in_reduction
diff --git a/flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90 b/flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90
index 7e6d7f09fbc679..175242bfc56566 100644
--- a/flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90
+++ b/flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90
@@ -46,4 +46,4 @@ subroutine taskReduction(x)
!$omp end task
!$omp end taskgroup
!$omp end parallel
-end subroutine
\ No newline at end of file
+end subroutine
diff --git a/flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90 b/flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90
index bc32cee93d47f1..29da1c56e0b3cb 100644
--- a/flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90
+++ b/flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90
@@ -31,4 +31,4 @@ subroutine omp_taskgroup_task_reduction()
!$omp taskgroup task_reduction(+:res)
res = res + 1
!$omp end taskgroup
-end subroutine omp_taskgroup_task_reduction
\ No newline at end of file
+end subroutine
diff --git a/flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90 b/flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90
index 6a5bc568efb8e4..ad41c1fbc1556c 100644
--- a/flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90
+++ b/flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90
@@ -25,7 +25,7 @@
!CHECK: return
!CHECK: }
-subroutine in_reduction
+subroutine in_reduction()
integer :: x
x = 0
!$omp taskgroup task_reduction(+:x)
@@ -33,4 +33,4 @@ subroutine in_reduction
x = x + 1
!$omp end task
!$omp end taskgroup
-end subroutine
\ No newline at end of file
+end subroutine
More information about the flang-commits
mailing list