[flang-commits] [flang] [Flang] Hoist concurrent-limit and concurrent-step expressions outside the outermost do concurrent loop (PR #111665)
via flang-commits
flang-commits at lists.llvm.org
Wed Oct 9 04:53:24 PDT 2024
https://github.com/harishch4 created https://github.com/llvm/llvm-project/pull/111665
…e the outermost do concurrent loop
>From 5e7fcb0b6daae36278cb93c3cef39a3d6e418560 Mon Sep 17 00:00:00 2001
From: harishch4 <harishcse44 at gmail.com>
Date: Wed, 9 Oct 2024 11:42:19 +0000
Subject: [PATCH] [Flang] Hoist concurrent-limit and concurrent-step
expressions outside the outermost do concurrent loop
---
flang/include/flang/Lower/PFTBuilder.h | 1 +
flang/lib/Lower/Bridge.cpp | 25 +++++++++++--
flang/test/Lower/do_concurrent.f90 | 50 ++++++++++++++++++++++++++
3 files changed, 73 insertions(+), 3 deletions(-)
create mode 100644 flang/test/Lower/do_concurrent.f90
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 7f1b93c564b4c4..de9abea3014356 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -362,6 +362,7 @@ struct Evaluation : EvaluationVariant {
bool activeConstruct{false}; // temporarily set for some constructs
mlir::Block *block{nullptr}; // isNewBlock block (ActionStmt, ConstructStmt)
int printIndex{0}; // (ActionStmt, ConstructStmt) evaluation index for dumps
+ mlir::Operation *op{nullptr}; // associated mlir operation
};
using ProgramVariant =
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0894a5903635e1..02871009020dd6 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2012,6 +2012,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
IncrementLoopNestInfo incrementLoopNestInfo;
const Fortran::parser::ScalarLogicalExpr *whileCondition = nullptr;
bool infiniteLoop = !loopControl.has_value();
+ bool isConcurrent = false;
if (infiniteLoop) {
assert(unstructuredContext && "infinite loop must be unstructured");
startBlock(headerBlock);
@@ -2042,6 +2043,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::LoopControl::Concurrent>(
&loopControl->u);
assert(concurrent && "invalid DO loop variant");
+ isConcurrent = true;
incrementLoopNestInfo = getConcurrentControl(
std::get<Fortran::parser::ConcurrentHeader>(concurrent->t),
std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent->t));
@@ -2070,7 +2072,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Increment loop begin code. (Infinite/while code was already generated.)
if (!infiniteLoop && !whileCondition)
- genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs);
+ genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs,
+ isConcurrent);
// Loop body code.
auto iter = eval.getNestedEvaluations().begin();
@@ -2128,12 +2131,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Generate FIR to begin a structured or unstructured increment loop nest.
void genFIRIncrementLoopBegin(
IncrementLoopNestInfo &incrementLoopNestInfo,
- llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
+ llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs,
+ bool isConcurrent) {
assert(!incrementLoopNestInfo.empty() && "empty loop nest");
mlir::Location loc = toLocation();
+ Fortran::lower::pft::Evaluation &eval = getEval();
+ Fortran::lower::pft::Evaluation *outermostEval = nullptr;
+ if (isConcurrent) {
+ outermostEval = &eval;
+ while (outermostEval->parentConstruct) {
+ outermostEval = outermostEval->parentConstruct;
+ }
+ }
+ mlir::OpBuilder::InsertPoint insertPt;
for (IncrementLoopInfo &info : incrementLoopNestInfo) {
info.loopVariable =
genLoopVariableAddress(loc, *info.loopVariableSym, info.isUnordered);
+ if (outermostEval && outermostEval->op) {
+ insertPt = builder->saveInsertionPoint();
+ builder->setInsertionPoint(outermostEval->op);
+ }
mlir::Value lowerValue = genControlValue(info.lowerExpr, info);
mlir::Value upperValue = genControlValue(info.upperExpr, info);
bool isConst = true;
@@ -2144,7 +2161,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
info.stepVariable = builder->createTemporary(loc, stepValue.getType());
builder->create<fir::StoreOp>(loc, stepValue, info.stepVariable);
}
-
+ if (outermostEval && outermostEval->op)
+ builder->restoreInsertionPoint(insertPt);
// Structured loop - generate fir.do_loop.
if (info.isStructured()) {
mlir::Type loopVarType = info.getLoopVariableType();
@@ -2179,6 +2197,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->setInsertionPointToStart(info.doLoop.getBody());
loopValue = info.doLoop.getRegionIterArgs()[0];
}
+ eval.op = info.doLoop;
// Update the loop variable value in case it has non-index references.
builder->create<fir::StoreOp>(loc, loopValue, info.loopVariable);
if (info.maskExpr) {
diff --git a/flang/test/Lower/do_concurrent.f90 b/flang/test/Lower/do_concurrent.f90
new file mode 100644
index 00000000000000..3a09ed97ebd84d
--- /dev/null
+++ b/flang/test/Lower/do_concurrent.f90
@@ -0,0 +1,50 @@
+! RUN: %flang_fc1 -emit-hlfir -o - %s | FileCheck %s
+
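+! Check that the concurrent-limit expressions of structured DO CONCURRENT loops, both in one multi-index header and in nested constructs, are evaluated before the outermost loop.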
+
+pure integer function bar(n, m)
+  implicit none
+  integer, intent(in) :: n, m
+  ! Pure helper used below as a concurrent-limit expression.
+  bar = m + n
+end function bar
+
+subroutine sub1(n)  ! CHECK-LABEL: func.func @_QPsub1(
+  implicit none
+  integer :: n, m, i, j
+  integer, dimension(n) :: a
+!CHECK: %[[LB1:.*]] = arith.constant 1 : i32
+!CHECK: %[[LB1_CVT:.*]] = fir.convert %[[LB1]] : (i32) -> index
+!CHECK: %[[UB1:.*]] = fir.load %{{.*}}#0 : !fir.ref<i32>
+!CHECK: %[[UB1_CVT:.*]] = fir.convert %[[UB1]] : (i32) -> index
+!CHECK: %[[LB2:.*]] = arith.constant 1 : i32
+!CHECK: %[[LB2_CVT:.*]] = fir.convert %[[LB2]] : (i32) -> index
+!CHECK: %[[UB2:.*]] = fir.call @_QPbar(%{{.*}}, %{{.*}}) proc_attrs<pure> fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> i32
+!CHECK: %[[UB2_CVT:.*]] = fir.convert %[[UB2]] : (i32) -> index
+!CHECK: fir.do_loop %{{.*}} = %[[LB1_CVT]] to %[[UB1_CVT]] step %{{.*}} unordered
+!CHECK: fir.do_loop %{{.*}} = %[[LB2_CVT]] to %[[UB2_CVT]] step %{{.*}} unordered
+  do concurrent(i=1:n, j=1:bar(n*m, n/m))  ! both limits are expected before the first loop
+    a(i) = n
+  end do
+end subroutine
+
+subroutine sub2(n)  ! CHECK-LABEL: func.func @_QPsub2(
+  implicit none
+  integer :: n, m, i, j
+  integer, dimension(n) :: a
+!CHECK: %[[LB1:.*]] = arith.constant 1 : i32
+!CHECK: %[[LB1_CVT:.*]] = fir.convert %[[LB1]] : (i32) -> index
+!CHECK: %[[UB1:.*]] = fir.load %{{.*}}#0 : !fir.ref<i32>
+!CHECK: %[[UB1_CVT:.*]] = fir.convert %[[UB1]] : (i32) -> index
+!CHECK: %[[LB2:.*]] = arith.constant 1 : i32
+!CHECK: %[[LB2_CVT:.*]] = fir.convert %[[LB2]] : (i32) -> index
+!CHECK: %[[UB2:.*]] = fir.call @_QPbar(%{{.*}}, %{{.*}}) proc_attrs<pure> fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> i32
+!CHECK: %[[UB2_CVT:.*]] = fir.convert %[[UB2]] : (i32) -> index
+!CHECK: fir.do_loop %{{.*}} = %[[LB1_CVT]] to %[[UB1_CVT]] step %{{.*}} unordered
+!CHECK: fir.do_loop %{{.*}} = %[[LB2_CVT]] to %[[UB2_CVT]] step %{{.*}} unordered
+  do concurrent(i=1:n)
+    do concurrent(j=1:bar(n*m, n/m))  ! inner limit is expected before the outer loop
+      a(i) = n
+    end do
+  end do
+end subroutine
More information about the flang-commits
mailing list