[flang-commits] [flang] [Flang] Hoist concurrent-limit and concurrent-step expressions outsid… (PR #111665)

via flang-commits flang-commits at lists.llvm.org
Wed Oct 9 04:53:24 PDT 2024


https://github.com/harishch4 created https://github.com/llvm/llvm-project/pull/111665

…e the outermost do concurrent loop

>From 5e7fcb0b6daae36278cb93c3cef39a3d6e418560 Mon Sep 17 00:00:00 2001
From: harishch4 <harishcse44 at gmail.com>
Date: Wed, 9 Oct 2024 11:42:19 +0000
Subject: [PATCH] [Flang] Hoist concurrent-limit and concurrent-step
 expressions outside the outermost do concurrent loop

---
 flang/include/flang/Lower/PFTBuilder.h |  1 +
 flang/lib/Lower/Bridge.cpp             | 25 +++++++++++--
 flang/test/Lower/do_concurrent.f90     | 50 ++++++++++++++++++++++++++
 3 files changed, 73 insertions(+), 3 deletions(-)
 create mode 100644 flang/test/Lower/do_concurrent.f90

diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 7f1b93c564b4c4..de9abea3014356 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -362,6 +362,7 @@ struct Evaluation : EvaluationVariant {
   bool activeConstruct{false}; // temporarily set for some constructs
   mlir::Block *block{nullptr}; // isNewBlock block (ActionStmt, ConstructStmt)
   int printIndex{0}; // (ActionStmt, ConstructStmt) evaluation index for dumps
+  mlir::Operation *op{nullptr}; // associated mlir operation
 };
 
 using ProgramVariant =
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0894a5903635e1..02871009020dd6 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2012,6 +2012,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     IncrementLoopNestInfo incrementLoopNestInfo;
     const Fortran::parser::ScalarLogicalExpr *whileCondition = nullptr;
     bool infiniteLoop = !loopControl.has_value();
+    bool isConcurrent = false;
     if (infiniteLoop) {
       assert(unstructuredContext && "infinite loop must be unstructured");
       startBlock(headerBlock);
@@ -2042,6 +2043,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           std::get_if<Fortran::parser::LoopControl::Concurrent>(
               &loopControl->u);
       assert(concurrent && "invalid DO loop variant");
+      isConcurrent = true;
       incrementLoopNestInfo = getConcurrentControl(
           std::get<Fortran::parser::ConcurrentHeader>(concurrent->t),
           std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent->t));
@@ -2070,7 +2072,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     // Increment loop begin code. (Infinite/while code was already generated.)
     if (!infiniteLoop && !whileCondition)
-      genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs);
+      genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs,
+                               isConcurrent);
 
     // Loop body code.
     auto iter = eval.getNestedEvaluations().begin();
@@ -2128,12 +2131,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Generate FIR to begin a structured or unstructured increment loop nest.
   void genFIRIncrementLoopBegin(
       IncrementLoopNestInfo &incrementLoopNestInfo,
-      llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
+      llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs,
+      bool isConcurrent) {
     assert(!incrementLoopNestInfo.empty() && "empty loop nest");
     mlir::Location loc = toLocation();
+    Fortran::lower::pft::Evaluation &eval = getEval();
+    Fortran::lower::pft::Evaluation *outermostEval = nullptr;
+    if (isConcurrent) {
+      outermostEval = &eval;
+      while (outermostEval->parentConstruct) {
+        outermostEval = outermostEval->parentConstruct;
+      }
+    }
+    mlir::OpBuilder::InsertPoint insertPt;
     for (IncrementLoopInfo &info : incrementLoopNestInfo) {
       info.loopVariable =
           genLoopVariableAddress(loc, *info.loopVariableSym, info.isUnordered);
+      if (outermostEval && outermostEval->op) {
+        insertPt = builder->saveInsertionPoint();
+        builder->setInsertionPoint(outermostEval->op);
+      }
       mlir::Value lowerValue = genControlValue(info.lowerExpr, info);
       mlir::Value upperValue = genControlValue(info.upperExpr, info);
       bool isConst = true;
@@ -2144,7 +2161,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         info.stepVariable = builder->createTemporary(loc, stepValue.getType());
         builder->create<fir::StoreOp>(loc, stepValue, info.stepVariable);
       }
-
+      if (outermostEval && outermostEval->op)
+        builder->restoreInsertionPoint(insertPt);
       // Structured loop - generate fir.do_loop.
       if (info.isStructured()) {
         mlir::Type loopVarType = info.getLoopVariableType();
@@ -2179,6 +2197,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           builder->setInsertionPointToStart(info.doLoop.getBody());
           loopValue = info.doLoop.getRegionIterArgs()[0];
         }
+        eval.op = info.doLoop;
         // Update the loop variable value in case it has non-index references.
         builder->create<fir::StoreOp>(loc, loopValue, info.loopVariable);
         if (info.maskExpr) {
diff --git a/flang/test/Lower/do_concurrent.f90 b/flang/test/Lower/do_concurrent.f90
new file mode 100644
index 00000000000000..3a09ed97ebd84d
--- /dev/null
+++ b/flang/test/Lower/do_concurrent.f90
@@ -0,0 +1,50 @@
+! RUN: %flang_fc1 -emit-hlfir -o - %s | FileCheck %s
+
+! Simple tests for structured concurrent loops with loop-control.
+
+pure function bar(n, m)  ! helper yielding a non-constant loop limit; its call is expected as fir.call with proc_attrs<pure>
+   implicit none
+   integer, intent(in) :: n, m
+   integer :: bar
+   bar = n + m  ! result is irrelevant to the test; only the call's placement is checked
+end function
+
+subroutine sub1(n)  ! one concurrent-header, two indices: j's limit must be computed before the outer fir.do_loop
+   implicit none
+   integer :: n, m, i, j
+   integer, dimension(n) :: a
+!CHECK: %[[LB1:.*]] = arith.constant 1 : i32
+!CHECK: %[[LB1_CVT:.*]] = fir.convert %[[LB1]] : (i32) -> index
+!CHECK: %[[UB1:.*]] = fir.load %{{.*}}#0 : !fir.ref<i32>
+!CHECK: %[[UB1_CVT:.*]] = fir.convert %[[UB1]] : (i32) -> index
+!CHECK: %[[LB2:.*]] = arith.constant 1 : i32
+!CHECK: %[[LB2_CVT:.*]] = fir.convert %[[LB2]] : (i32) -> index
+!CHECK: %[[UB2:.*]] = fir.call @_QPbar(%{{.*}}, %{{.*}}) proc_attrs<pure> fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> i32
+!CHECK: %[[UB2_CVT:.*]] = fir.convert %[[UB2]] : (i32) -> index
+!CHECK: fir.do_loop %{{.*}} = %[[LB1_CVT]] to %[[UB1_CVT]] step %{{.*}} unordered
+!CHECK: fir.do_loop %{{.*}} = %[[LB2_CVT]] to %[[UB2_CVT]] step %{{.*}} unordered
+   do concurrent(i=1:n, j=1:bar(n*m, n/m))  ! bar(...) may not reference i (same header), so hoisting is legal
+      a(i) = n
+   end do
+end subroutine
+
+subroutine sub2(n)  ! nested DO CONCURRENT constructs: the inner limit is expected above the outer fir.do_loop
+   implicit none
+   integer :: n, m, i, j
+   integer, dimension(n) :: a
+!CHECK: %[[LB1:.*]] = arith.constant 1 : i32
+!CHECK: %[[LB1_CVT:.*]] = fir.convert %[[LB1]] : (i32) -> index
+!CHECK: %[[UB1:.*]] = fir.load %{{.*}}#0 : !fir.ref<i32>
+!CHECK: %[[UB1_CVT:.*]] = fir.convert %[[UB1]] : (i32) -> index
+!CHECK: %[[LB2:.*]] = arith.constant 1 : i32
+!CHECK: %[[LB2_CVT:.*]] = fir.convert %[[LB2]] : (i32) -> index
+!CHECK: %[[UB2:.*]] = fir.call @_QPbar(%{{.*}}, %{{.*}}) proc_attrs<pure> fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> i32
+!CHECK: %[[UB2_CVT:.*]] = fir.convert %[[UB2]] : (i32) -> index
+!CHECK: fir.do_loop %{{.*}} = %[[LB1_CVT]] to %[[UB1_CVT]] step %{{.*}} unordered
+!CHECK: fir.do_loop %{{.*}} = %[[LB2_CVT]] to %[[UB2_CVT]] step %{{.*}} unordered
+   do concurrent(i=1:n)
+      do concurrent(j=1:bar(n*m, n/m))  ! limit does not use i; hoisting would be wrong if it did (e.g. j=1:i)
+         a(i) = n
+      end do
+   end do
+end subroutine



More information about the flang-commits mailing list