[flang-commits] [flang] [flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause (PR #94718)

Congcong Cai via flang-commits flang-commits at lists.llvm.org
Thu Jun 6 22:32:29 PDT 2024


https://github.com/HerrCai0907 updated https://github.com/llvm/llvm-project/pull/94718

>From b069eff6b633517ae848ce01d548b39b3101b5c2 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Thu, 6 Jun 2024 15:56:31 -0700
Subject: [PATCH 1/4] [flang] Generate fir.do_loop reduce from DO CONCURRENT
 REDUCE clause

---
 flang/lib/Lower/Bridge.cpp  | 61 +++++++++++++++++++++++++++++++++++--
 flang/test/Lower/loops3.f90 | 23 ++++++++++++++
 2 files changed, 82 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Lower/loops3.f90

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 512c7a349ae21..d0a0a36500f61 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -104,7 +104,7 @@ struct IncrementLoopInfo {
 
   bool hasLocalitySpecs() const {
     return !localSymList.empty() || !localInitSymList.empty() ||
-           !sharedSymList.empty();
+           !reduceSymList.empty() || !sharedSymList.empty();
   }
 
   // Data members common to both structured and unstructured loops.
@@ -116,6 +116,9 @@ struct IncrementLoopInfo {
   bool isUnordered; // do concurrent, forall
   llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
   llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
+  llvm::SmallVector<
+      std::pair<fir::ReduceOperationEnum, const Fortran::semantics::Symbol *>>
+      reduceSymList;
   llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
   mlir::Value loopVariable = nullptr;
 
@@ -1741,6 +1744,36 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     builder->create<fir::UnreachableOp>(loc);
   }
 
+  fir::ReduceOperationEnum
+  getReduceOperationEnum(const Fortran::parser::ReductionOperator &rOpr) {
+    switch (rOpr.v) {
+    case Fortran::parser::ReductionOperator::Operator::Plus:
+      return fir::ReduceOperationEnum::Add;
+    case Fortran::parser::ReductionOperator::Operator::Multiply:
+      return fir::ReduceOperationEnum::Multiply;
+    case Fortran::parser::ReductionOperator::Operator::And:
+      return fir::ReduceOperationEnum::AND;
+    case Fortran::parser::ReductionOperator::Operator::Or:
+      return fir::ReduceOperationEnum::OR;
+    case Fortran::parser::ReductionOperator::Operator::Eqv:
+      return fir::ReduceOperationEnum::EQV;
+    case Fortran::parser::ReductionOperator::Operator::Neqv:
+      return fir::ReduceOperationEnum::NEQV;
+    case Fortran::parser::ReductionOperator::Operator::Max:
+      return fir::ReduceOperationEnum::MAX;
+    case Fortran::parser::ReductionOperator::Operator::Min:
+      return fir::ReduceOperationEnum::MIN;
+    case Fortran::parser::ReductionOperator::Operator::Iand:
+      return fir::ReduceOperationEnum::IAND;
+    case Fortran::parser::ReductionOperator::Operator::Ior:
+      return fir::ReduceOperationEnum::IOR;
+    case Fortran::parser::ReductionOperator::Operator::Ieor:
+      return fir::ReduceOperationEnum::EIOR;
+    }
+    fir::emitFatalError(toLocation(), "illegal reduction operator");
+    return fir::ReduceOperationEnum::Add;
+  }
+
   /// Collect DO CONCURRENT or FORALL loop control information.
   IncrementLoopNestInfo getConcurrentControl(
       const Fortran::parser::ConcurrentHeader &header,
@@ -1763,6 +1796,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
               std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
         for (const Fortran::parser::Name &x : localInitList->v)
           info.localInitSymList.push_back(x.symbol);
+      if (const auto *reduceList =
+              std::get_if<Fortran::parser::LocalitySpec::Reduce>(&x.u)) {
+        fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum(
+            std::get<Fortran::parser::ReductionOperator>(reduceList->t));
+        for (const Fortran::parser::Name &x :
+             std::get<std::list<Fortran::parser::Name>>(reduceList->t)) {
+          info.reduceSymList.push_back(
+              std::make_pair(reduce_operation, x.symbol));
+        }
+      }
       if (const auto *sharedList =
               std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
         for (const Fortran::parser::Name &x : sharedList->v)
@@ -1955,9 +1998,23 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         mlir::Type loopVarType = info.getLoopVariableType();
         mlir::Value loopValue;
         if (info.isUnordered) {
+          llvm::SmallVector<mlir::Value> reduceOperands;
+          llvm::SmallVector<mlir::Attribute> reduceAttrs;
+          // Create DO CONCURRENT reduce operations and attributes
+          for (const auto reduceSym : info.reduceSymList) {
+            const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
+            const Fortran::semantics::Symbol *sym = reduceSym.second;
+            fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
+            reduceOperands.push_back(fir::getBase(exv));
+            auto reduce_attr =
+                fir::ReduceAttr::get(builder->getContext(), reduce_operation);
+            reduceAttrs.push_back(reduce_attr);
+          }
           // The loop variable value is explicitly updated.
           info.doLoop = builder->create<fir::DoLoopOp>(
-              loc, lowerValue, upperValue, stepValue, /*unordered=*/true);
+              loc, lowerValue, upperValue, stepValue, /*unordered=*/true,
+              /*finalCountValue=*/false, /*iterArgs=*/std::nullopt,
+              llvm::ArrayRef<mlir::Value>(reduceOperands), reduceAttrs);
           builder->setInsertionPointToStart(info.doLoop.getBody());
           loopValue = builder->createConvert(loc, loopVarType,
                                              info.doLoop.getInductionVar());
diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90
new file mode 100644
index 0000000000000..dd24e26d72c31
--- /dev/null
+++ b/flang/test/Lower/loops3.f90
@@ -0,0 +1,23 @@
+! Test do concurrent reduction
+! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s
+
+! CHECK-LABEL: loop_test
+subroutine loop_test
+  integer(4) :: i, j, k, tmp, sum = 0
+  real :: m
+
+  i = 100
+  j = 200
+  k = 300
+
+  ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"
+  ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref<i32>
+  ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+  ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+  ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered reduce(#fir.reduce_attr<add> -> %[[VAL_1]] : !fir.ref<i32>, #fir.reduce_attr<max> -> %[[VAL_0]] : !fir.ref<f32>) {
+  do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m)
+    tmp = i + j + k
+    sum = tmp + sum
+    m = max(m, sum)
+  enddo
+end subroutine loop_test

>From c2607709f550caff3bdc6fe71329d941b87bd244 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Thu, 6 Jun 2024 21:14:10 -0700
Subject: [PATCH 2/4] [flang] Close a brace

---
 flang/test/Lower/loops3.f90 | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90
index dd24e26d72c31..2e62ee480ec8a 100644
--- a/flang/test/Lower/loops3.f90
+++ b/flang/test/Lower/loops3.f90
@@ -10,7 +10,7 @@ subroutine loop_test
   j = 200
   k = 300
 
-  ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"
+  ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"}
   ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref<i32>
   ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
   ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {

>From 3512291f5da53d5f9b9a0340822352cc1326f7d6 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Thu, 6 Jun 2024 21:39:51 -0700
Subject: [PATCH 3/4] [flang] Fix a comment

---
 flang/lib/Lower/Bridge.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index d0a0a36500f61..dca71256192ed 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2000,7 +2000,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         if (info.isUnordered) {
           llvm::SmallVector<mlir::Value> reduceOperands;
           llvm::SmallVector<mlir::Attribute> reduceAttrs;
-          // Create DO CONCURRENT reduce operations and attributes
+          // Create DO CONCURRENT reduce operands and attributes
           for (const auto reduceSym : info.reduceSymList) {
             const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
             const Fortran::semantics::Symbol *sym = reduceSym.second;

>From f9756968593f777cf3f12027eea2b89c55a2ea63 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Thu, 6 Jun 2024 22:30:52 -0700
Subject: [PATCH 4/4] [flang] Use llvm_unreachable

---
 flang/lib/Lower/Bridge.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index dca71256192ed..14e99757925ac 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1770,8 +1770,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     case Fortran::parser::ReductionOperator::Operator::Ieor:
       return fir::ReduceOperationEnum::EIOR;
     }
-    fir::emitFatalError(toLocation(), "illegal reduction operator");
-    return fir::ReduceOperationEnum::Add;
+    llvm_unreachable("illegal reduction operator");
   }
 
   /// Collect DO CONCURRENT or FORALL loop control information.



More information about the flang-commits mailing list