[flang-commits] [flang] f11e08f - [flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause (#94718)

via flang-commits flang-commits at lists.llvm.org
Mon Jun 10 08:41:09 PDT 2024


Author: khaki3
Date: 2024-06-10T08:41:05-07:00
New Revision: f11e08fb26642fddebdefca5bec933fe39e4bd03

URL: https://github.com/llvm/llvm-project/commit/f11e08fb26642fddebdefca5bec933fe39e4bd03
DIFF: https://github.com/llvm/llvm-project/commit/f11e08fb26642fddebdefca5bec933fe39e4bd03.diff

LOG: [flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause (#94718)

Derived from #92480. This PR updates the lowering of DO CONCURRENT to
support the F'2023 REDUCE clause. The structure `IncrementLoopInfo` is
extended with `reduceSymList`, which pairs each reduction operation with
its symbol. The function `getConcurrentControl` populates
`reduceSymList` for the innermost loop. Finally,
`genFIRIncrementLoopBegin` builds the `fir.do_loop` with the reduction
operands and attributes.

Added: 
    flang/test/Lower/loops3.f90

Modified: 
    flang/lib/Lower/Bridge.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 512c7a349ae21..14e99757925ac 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -104,7 +104,7 @@ struct IncrementLoopInfo {
 
   bool hasLocalitySpecs() const {
     return !localSymList.empty() || !localInitSymList.empty() ||
-           !sharedSymList.empty();
+           !reduceSymList.empty() || !sharedSymList.empty(); // a non-empty reduce list counts too
   }
 
   // Data members common to both structured and unstructured loops.
@@ -116,6 +116,9 @@ struct IncrementLoopInfo {
   bool isUnordered; // do concurrent, forall
   llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
   llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
+  llvm::SmallVector<
+      std::pair<fir::ReduceOperationEnum, const Fortran::semantics::Symbol *>>
+      reduceSymList;
   llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
   mlir::Value loopVariable = nullptr;
 
@@ -1741,6 +1744,35 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     builder->create<fir::UnreachableOp>(loc);
   }
 
+  fir::ReduceOperationEnum // FIR reduction kind matching the parsed operator
+  getReduceOperationEnum(const Fortran::parser::ReductionOperator &rOpr) {
+    switch (rOpr.v) { // no default label: each listed operator returns directly
+    case Fortran::parser::ReductionOperator::Operator::Plus:
+      return fir::ReduceOperationEnum::Add;
+    case Fortran::parser::ReductionOperator::Operator::Multiply:
+      return fir::ReduceOperationEnum::Multiply;
+    case Fortran::parser::ReductionOperator::Operator::And:
+      return fir::ReduceOperationEnum::AND;
+    case Fortran::parser::ReductionOperator::Operator::Or:
+      return fir::ReduceOperationEnum::OR;
+    case Fortran::parser::ReductionOperator::Operator::Eqv:
+      return fir::ReduceOperationEnum::EQV;
+    case Fortran::parser::ReductionOperator::Operator::Neqv:
+      return fir::ReduceOperationEnum::NEQV;
+    case Fortran::parser::ReductionOperator::Operator::Max:
+      return fir::ReduceOperationEnum::MAX;
+    case Fortran::parser::ReductionOperator::Operator::Min:
+      return fir::ReduceOperationEnum::MIN;
+    case Fortran::parser::ReductionOperator::Operator::Iand:
+      return fir::ReduceOperationEnum::IAND;
+    case Fortran::parser::ReductionOperator::Operator::Ior:
+      return fir::ReduceOperationEnum::IOR;
+    case Fortran::parser::ReductionOperator::Operator::Ieor:
+      return fir::ReduceOperationEnum::EIOR; // FIR enumerator is spelled EIOR
+    }
+    llvm_unreachable("illegal reduction operator"); // rOpr.v matched no case
+  }
+
   /// Collect DO CONCURRENT or FORALL loop control information.
   IncrementLoopNestInfo getConcurrentControl(
       const Fortran::parser::ConcurrentHeader &header,
@@ -1763,6 +1795,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
               std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
         for (const Fortran::parser::Name &x : localInitList->v)
           info.localInitSymList.push_back(x.symbol);
+      if (const auto *reduceList =
+              std::get_if<Fortran::parser::LocalitySpec::Reduce>(&x.u)) {
+        fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum(
+            std::get<Fortran::parser::ReductionOperator>(reduceList->t));
+        for (const Fortran::parser::Name &x :
+             std::get<std::list<Fortran::parser::Name>>(reduceList->t)) {
+          info.reduceSymList.push_back(
+              std::make_pair(reduce_operation, x.symbol));
+        }
+      }
       if (const auto *sharedList =
               std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
         for (const Fortran::parser::Name &x : sharedList->v)
@@ -1955,9 +1997,23 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         mlir::Type loopVarType = info.getLoopVariableType();
         mlir::Value loopValue;
         if (info.isUnordered) {
+          llvm::SmallVector<mlir::Value> reduceOperands;
+          llvm::SmallVector<mlir::Attribute> reduceAttrs;
+          // Create DO CONCURRENT reduce operands and attributes
+          for (const auto reduceSym : info.reduceSymList) {
+            const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
+            const Fortran::semantics::Symbol *sym = reduceSym.second;
+            fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
+            reduceOperands.push_back(fir::getBase(exv));
+            auto reduce_attr =
+                fir::ReduceAttr::get(builder->getContext(), reduce_operation);
+            reduceAttrs.push_back(reduce_attr);
+          }
           // The loop variable value is explicitly updated.
           info.doLoop = builder->create<fir::DoLoopOp>(
-              loc, lowerValue, upperValue, stepValue, /*unordered=*/true);
+              loc, lowerValue, upperValue, stepValue, /*unordered=*/true,
+              /*finalCountValue=*/false, /*iterArgs=*/std::nullopt,
+              llvm::ArrayRef<mlir::Value>(reduceOperands), reduceAttrs);
           builder->setInsertionPointToStart(info.doLoop.getBody());
           loopValue = builder->createConvert(loc, loopVarType,
                                              info.doLoop.getInductionVar());

diff  --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90
new file mode 100644
index 0000000000000..2e62ee480ec8a
--- /dev/null
+++ b/flang/test/Lower/loops3.f90
@@ -0,0 +1,23 @@
+! Test do concurrent reduction
+! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s
+
+! CHECK-LABEL: loop_test
+subroutine loop_test
+  integer(4) :: i, j, k, tmp, sum = 0 ! sum is initialized, hence SAVEd: lowered via fir.address_of
+  real :: m ! plain local: lowered as fir.alloca
+
+  i = 100
+  j = 200
+  k = 300
+
+  ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"}
+  ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref<i32>
+  ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+  ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+  ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered reduce(#fir.reduce_attr<add> -> %[[VAL_1]] : !fir.ref<i32>, #fir.reduce_attr<max> -> %[[VAL_0]] : !fir.ref<f32>) {
+  do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m) ! reduce operands are expected on the innermost loop only
+    tmp = i + j + k
+    sum = tmp + sum
+    m = max(m, sum)
+  enddo
+end subroutine loop_test


        


More information about the flang-commits mailing list