[flang-commits] [flang] [flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause (PR #94718)
via flang-commits
flang-commits at lists.llvm.org
Thu Jun 6 21:10:14 PDT 2024
https://github.com/khaki3 created https://github.com/llvm/llvm-project/pull/94718
Derived from #92480. This PR updates the lowering of DO CONCURRENT to support the Fortran 2023 REDUCE clause. The `IncrementLoopInfo` structure is extended with `reduceSymList`, which holds (reduction operation, symbol) pairs. The function `getConcurrentControl` populates `reduceSymList` for the innermost loop. Finally, `genFIRIncrementLoopBegin` builds `fir.do_loop` with the reduction operands and their `#fir.reduce_attr` attributes.
From b069eff6b633517ae848ce01d548b39b3101b5c2 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Thu, 6 Jun 2024 15:56:31 -0700
Subject: [PATCH] [flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE
clause
---
flang/lib/Lower/Bridge.cpp | 61 +++++++++++++++++++++++++++++++++++--
flang/test/Lower/loops3.f90 | 23 ++++++++++++++
2 files changed, 82 insertions(+), 2 deletions(-)
create mode 100644 flang/test/Lower/loops3.f90
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 512c7a349ae21..d0a0a36500f61 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -104,7 +104,7 @@ struct IncrementLoopInfo {
bool hasLocalitySpecs() const {
return !localSymList.empty() || !localInitSymList.empty() ||
- !sharedSymList.empty();
+ !reduceSymList.empty() || !sharedSymList.empty();
}
// Data members common to both structured and unstructured loops.
@@ -116,6 +116,9 @@ struct IncrementLoopInfo {
bool isUnordered; // do concurrent, forall
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
+ llvm::SmallVector<
+ std::pair<fir::ReduceOperationEnum, const Fortran::semantics::Symbol *>>
+ reduceSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
@@ -1741,6 +1744,36 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->create<fir::UnreachableOp>(loc);
}
+ /// Map a parsed DO CONCURRENT REDUCE-clause operator to its FIR reduction.
+ fir::ReduceOperationEnum
+ getReduceOperationEnum(const Fortran::parser::ReductionOperator &rOpr) {
+ switch (rOpr.v) {
+ case Fortran::parser::ReductionOperator::Operator::Plus:
+ return fir::ReduceOperationEnum::Add;
+ case Fortran::parser::ReductionOperator::Operator::Multiply:
+ return fir::ReduceOperationEnum::Multiply;
+ case Fortran::parser::ReductionOperator::Operator::And:
+ return fir::ReduceOperationEnum::AND;
+ case Fortran::parser::ReductionOperator::Operator::Or:
+ return fir::ReduceOperationEnum::OR;
+ case Fortran::parser::ReductionOperator::Operator::Eqv:
+ return fir::ReduceOperationEnum::EQV;
+ case Fortran::parser::ReductionOperator::Operator::Neqv:
+ return fir::ReduceOperationEnum::NEQV;
+ case Fortran::parser::ReductionOperator::Operator::Max:
+ return fir::ReduceOperationEnum::MAX;
+ case Fortran::parser::ReductionOperator::Operator::Min:
+ return fir::ReduceOperationEnum::MIN;
+ case Fortran::parser::ReductionOperator::Operator::Iand:
+ return fir::ReduceOperationEnum::IAND;
+ case Fortran::parser::ReductionOperator::Operator::Ior:
+ return fir::ReduceOperationEnum::IOR;
+ case Fortran::parser::ReductionOperator::Operator::Ieor:
+ return fir::ReduceOperationEnum::EIOR;
+ }
+ fir::emitFatalError(toLocation(), "illegal reduction operator"); // noreturn
+ }
+
/// Collect DO CONCURRENT or FORALL loop control information.
IncrementLoopNestInfo getConcurrentControl(
const Fortran::parser::ConcurrentHeader &header,
@@ -1763,6 +1796,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
for (const Fortran::parser::Name &x : localInitList->v)
info.localInitSymList.push_back(x.symbol);
+ if (const auto *reduceList =
+ std::get_if<Fortran::parser::LocalitySpec::Reduce>(&x.u)) {
+ fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum(
+ std::get<Fortran::parser::ReductionOperator>(reduceList->t));
+ for (const Fortran::parser::Name &x :
+ std::get<std::list<Fortran::parser::Name>>(reduceList->t)) {
+ info.reduceSymList.push_back(
+ std::make_pair(reduce_operation, x.symbol));
+ }
+ }
if (const auto *sharedList =
std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
for (const Fortran::parser::Name &x : sharedList->v)
@@ -1955,9 +1998,23 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::Type loopVarType = info.getLoopVariableType();
mlir::Value loopValue;
if (info.isUnordered) {
+ llvm::SmallVector<mlir::Value> reduceOperands;
+ llvm::SmallVector<mlir::Attribute> reduceAttrs;
+ // Create DO CONCURRENT reduce operations and attributes
+ for (const auto reduceSym : info.reduceSymList) {
+ const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
+ const Fortran::semantics::Symbol *sym = reduceSym.second;
+ fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
+ reduceOperands.push_back(fir::getBase(exv));
+ auto reduce_attr =
+ fir::ReduceAttr::get(builder->getContext(), reduce_operation);
+ reduceAttrs.push_back(reduce_attr);
+ }
// The loop variable value is explicitly updated.
info.doLoop = builder->create<fir::DoLoopOp>(
- loc, lowerValue, upperValue, stepValue, /*unordered=*/true);
+ loc, lowerValue, upperValue, stepValue, /*unordered=*/true,
+ /*finalCountValue=*/false, /*iterArgs=*/std::nullopt,
+ llvm::ArrayRef<mlir::Value>(reduceOperands), reduceAttrs);
builder->setInsertionPointToStart(info.doLoop.getBody());
loopValue = builder->createConvert(loc, loopVarType,
info.doLoop.getInductionVar());
diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90
new file mode 100644
index 0000000000000..dd24e26d72c31
--- /dev/null
+++ b/flang/test/Lower/loops3.f90
@@ -0,0 +1,23 @@
+! Test lowering of DO CONCURRENT REDUCE to fir.do_loop reduce operands
+! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s
+
+! CHECK-LABEL: loop_test
+subroutine loop_test
+ integer(4) :: i, j, k, tmp, sum = 0
+ real :: m
+
+ i = 100
+ j = 200
+ k = 300
+ m = 0.0
+ ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"
+ ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref<i32>
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered reduce(#fir.reduce_attr<add> -> %[[VAL_1]] : !fir.ref<i32>, #fir.reduce_attr<max> -> %[[VAL_0]] : !fir.ref<f32>) {
+ do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m)
+ tmp = i + j + k
+ sum = tmp + sum
+ m = max(m, real(sum))
+ enddo
+end subroutine loop_test
More information about the flang-commits
mailing list