[flang-commits] [flang] f11e08f - [flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause (#94718)
via flang-commits
flang-commits at lists.llvm.org
Mon Jun 10 08:41:09 PDT 2024
Author: khaki3
Date: 2024-06-10T08:41:05-07:00
New Revision: f11e08fb26642fddebdefca5bec933fe39e4bd03
URL: https://github.com/llvm/llvm-project/commit/f11e08fb26642fddebdefca5bec933fe39e4bd03
DIFF: https://github.com/llvm/llvm-project/commit/f11e08fb26642fddebdefca5bec933fe39e4bd03.diff
LOG: [flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause (#94718)
Derived from #92480. This PR updates the lowering of DO CONCURRENT to
support the F'2023 REDUCE clause. The structure `IncrementLoopInfo` is
extended with `reduceSymList`, which stores each reduction operation
together with its symbol. The function `getConcurrentControl`
constructs `reduceSymList` for the innermost loop. Finally,
`genFIRIncrementLoopBegin` builds `fir.do_loop` with reduction operands.
Added:
flang/test/Lower/loops3.f90
Modified:
flang/lib/Lower/Bridge.cpp
Removed:
################################################################################
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 512c7a349ae21..14e99757925ac 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -104,7 +104,7 @@ struct IncrementLoopInfo {
bool hasLocalitySpecs() const {
return !localSymList.empty() || !localInitSymList.empty() ||
- !sharedSymList.empty();
+ !reduceSymList.empty() || !sharedSymList.empty();
}
// Data members common to both structured and unstructured loops.
@@ -116,6 +116,9 @@ struct IncrementLoopInfo {
bool isUnordered; // do concurrent, forall
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
+ llvm::SmallVector<
+ std::pair<fir::ReduceOperationEnum, const Fortran::semantics::Symbol *>>
+ reduceSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
@@ -1741,6 +1744,35 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->create<fir::UnreachableOp>(loc);
}
+ fir::ReduceOperationEnum // Returns the FIR reduction enumerator for the parsed REDUCE operator rOpr.
+ getReduceOperationEnum(const Fortran::parser::ReductionOperator &rOpr) {
+ switch (rOpr.v) { // dispatch on the operator kind stored in rOpr.v; every case returns
+ case Fortran::parser::ReductionOperator::Operator::Plus:
+ return fir::ReduceOperationEnum::Add;
+ case Fortran::parser::ReductionOperator::Operator::Multiply:
+ return fir::ReduceOperationEnum::Multiply;
+ case Fortran::parser::ReductionOperator::Operator::And:
+ return fir::ReduceOperationEnum::AND;
+ case Fortran::parser::ReductionOperator::Operator::Or:
+ return fir::ReduceOperationEnum::OR;
+ case Fortran::parser::ReductionOperator::Operator::Eqv:
+ return fir::ReduceOperationEnum::EQV;
+ case Fortran::parser::ReductionOperator::Operator::Neqv:
+ return fir::ReduceOperationEnum::NEQV;
+ case Fortran::parser::ReductionOperator::Operator::Max:
+ return fir::ReduceOperationEnum::MAX;
+ case Fortran::parser::ReductionOperator::Operator::Min:
+ return fir::ReduceOperationEnum::MIN;
+ case Fortran::parser::ReductionOperator::Operator::Iand:
+ return fir::ReduceOperationEnum::IAND;
+ case Fortran::parser::ReductionOperator::Operator::Ior:
+ return fir::ReduceOperationEnum::IOR;
+ case Fortran::parser::ReductionOperator::Operator::Ieor:
+ return fir::ReduceOperationEnum::EIOR; // the FIR enumerator returned for Ieor is named EIOR
+ }
+ llvm_unreachable("illegal reduction operator"); // reached only when rOpr.v matches none of the cases above
+ }
+
/// Collect DO CONCURRENT or FORALL loop control information.
IncrementLoopNestInfo getConcurrentControl(
const Fortran::parser::ConcurrentHeader &header,
@@ -1763,6 +1795,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
for (const Fortran::parser::Name &x : localInitList->v)
info.localInitSymList.push_back(x.symbol);
+ if (const auto *reduceList =
+ std::get_if<Fortran::parser::LocalitySpec::Reduce>(&x.u)) {
+ fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum(
+ std::get<Fortran::parser::ReductionOperator>(reduceList->t));
+ for (const Fortran::parser::Name &x :
+ std::get<std::list<Fortran::parser::Name>>(reduceList->t)) {
+ info.reduceSymList.push_back(
+ std::make_pair(reduce_operation, x.symbol));
+ }
+ }
if (const auto *sharedList =
std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
for (const Fortran::parser::Name &x : sharedList->v)
@@ -1955,9 +1997,23 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::Type loopVarType = info.getLoopVariableType();
mlir::Value loopValue;
if (info.isUnordered) {
+ llvm::SmallVector<mlir::Value> reduceOperands;
+ llvm::SmallVector<mlir::Attribute> reduceAttrs;
+ // Create DO CONCURRENT reduce operands and attributes
+ for (const auto reduceSym : info.reduceSymList) {
+ const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
+ const Fortran::semantics::Symbol *sym = reduceSym.second;
+ fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
+ reduceOperands.push_back(fir::getBase(exv));
+ auto reduce_attr =
+ fir::ReduceAttr::get(builder->getContext(), reduce_operation);
+ reduceAttrs.push_back(reduce_attr);
+ }
// The loop variable value is explicitly updated.
info.doLoop = builder->create<fir::DoLoopOp>(
- loc, lowerValue, upperValue, stepValue, /*unordered=*/true);
+ loc, lowerValue, upperValue, stepValue, /*unordered=*/true,
+ /*finalCountValue=*/false, /*iterArgs=*/std::nullopt,
+ llvm::ArrayRef<mlir::Value>(reduceOperands), reduceAttrs);
builder->setInsertionPointToStart(info.doLoop.getBody());
loopValue = builder->createConvert(loc, loopVarType,
info.doLoop.getInductionVar());
diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90
new file mode 100644
index 0000000000000..2e62ee480ec8a
--- /dev/null
+++ b/flang/test/Lower/loops3.f90
@@ -0,0 +1,23 @@
+! Test do concurrent reduction
+! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s
+
+! CHECK-LABEL: loop_test
+subroutine loop_test
+ ! Lowering test: a three-index DO CONCURRENT with LOCAL and two REDUCE
+ ! locality specs. Only the generated FIR is inspected; the values that
+ ! the loop computes are not examined.
+ integer(4) :: i, j, k, tmp, sum = 0
+ real :: m
+
+ i = 100
+ j = 200
+ k = 300
+
+ ! Capture the storage of the two reduction variables (m, then sum).
+ ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"}
+ ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref<i32>
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+ ! Only the third (innermost) loop is expected to carry the reduce clause.
+ ! Its operands are written as uses of VAL_1 / VAL_0 (no ":regex" part), so
+ ! they must be exactly the values captured above instead of matching any
+ ! SSA value.
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered reduce(#fir.reduce_attr<add> -> %[[VAL_1]] : !fir.ref<i32>, #fir.reduce_attr<max> -> %[[VAL_0]] : !fir.ref<f32>) {
+ do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m)
+ tmp = i + j + k
+ sum = tmp + sum
+ m = max(m, sum)
+ enddo
+end subroutine loop_test
More information about the flang-commits
mailing list