[flang-commits] [flang] [flang][openacc] Support early return in acc.loop (PR #73841)
via flang-commits
flang-commits at lists.llvm.org
Wed Nov 29 10:31:21 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Early return is accepted in an OpenACC loop that is not directly nested in a compute construct. Since the acc.loop operation has a region, the `func.return` operation cannot be used directly inside that region.
An early return is materialized by an `acc.yield` operation returning a `true` value. In this case, the normal end of the `acc.loop` region yields a `false` value.
A conditional branch on the `acc.loop` result then branches either to the `finalBlock` or to the continue block, depending on whether an early exit was produced in the acc.loop.
---
Full diff: https://github.com/llvm/llvm-project/pull/73841.diff
4 Files Affected:
- (modified) flang/include/flang/Lower/OpenACC.h (+10-3)
- (modified) flang/lib/Lower/Bridge.cpp (+21-2)
- (modified) flang/lib/Lower/OpenACC.cpp (+80-14)
- (added) flang/test/Lower/OpenACC/acc-loop-exit.f90 (+37)
``````````diff
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 409956f0ecb309f..f23e4726f33e00a 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -64,9 +64,10 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
static constexpr llvm::StringRef declarePostDeallocSuffix =
"_acc_declare_update_desc_post_dealloc";
-void genOpenACCConstruct(AbstractConverter &,
- Fortran::semantics::SemanticsContext &,
- pft::Evaluation &, const parser::OpenACCConstruct &);
+mlir::Value genOpenACCConstruct(AbstractConverter &,
+ Fortran::semantics::SemanticsContext &,
+ pft::Evaluation &,
+ const parser::OpenACCConstruct &);
void genOpenACCDeclarativeConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
StatementContext &,
@@ -112,6 +113,12 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
mlir::Location);
+bool isInOpenACCLoop(fir::FirOpBuilder &);
+
+void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);
+
+void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location);
+
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 23c48cc7bd97874..45da1355df168e2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2382,11 +2382,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
- genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
+ mlir::Value exitCond = genOpenACCConstruct( // null unless an acc.loop yields an early-exit flag
+ *this, bridge.getSemanticsContext(), getEval(), acc);
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
genFIR(e);
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
+
+ const Fortran::parser::OpenACCLoopConstruct *accLoop =
+ std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
+ if (accLoop && exitCond) { // acc.loop returned an i1 early-exit flag
+ Fortran::lower::pft::FunctionLikeUnit *funit =
+ getEval().getOwningProcedure();
+ assert(funit && "not inside main program, function or subroutine");
+ mlir::Block *continueBlock =
+ builder->getBlock()->splitBlock(builder->getBlock()->end()); // new empty block for code after the loop
+ builder->create<mlir::cf::CondBranchOp>(toLocation(), exitCond,
+ funit->finalBlock, continueBlock); // true: jump to the block holding the return
+ builder->setInsertionPointToEnd(continueBlock);
+ }
}
void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
@@ -4091,10 +4105,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Branch to the last block of the SUBROUTINE, which has the actual return.
if (!funit->finalBlock) {
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+ Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(*builder);
funit->finalBlock = builder->createBlock(&builder->getRegion());
builder->restoreInsertionPoint(insPt);
}
- builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
+
+ if (Fortran::lower::isInOpenACCLoop(*builder))
+ Fortran::lower::genEarlyReturnInOpenACCLoop(*builder, loc);
+ else
+ builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
}
void genFIR(const Fortran::parser::CycleStmt &) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 8c6c22210cf0894..e2abed1b9f4f675 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -25,10 +25,12 @@
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "llvm/Frontend/OpenACC/ACC.h.inc"
// Special value for * passed in device_type or gang clauses.
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::pft::Evaluation &eval,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments,
- bool outerCombined = false) {
- llvm::ArrayRef<mlir::Type> argTy;
- Op op = builder.create<Op>(loc, argTy, operands);
+ bool outerCombined = false,
+ llvm::SmallVector<mlir::Type> retTy = {},
+ mlir::Value yieldValue = {}) {
+ Op op = builder.create<Op>(loc, retTy, operands);
builder.createBlock(&op.getRegion());
mlir::Block &block = op.getRegion().back();
builder.setInsertionPointToStart(&block);
@@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::acc::YieldOp>(
builder, eval.getNestedEvaluations());
- builder.create<Terminator>(loc);
+ if (yieldValue) {
+ if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
+ Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
+ yieldValue.getDefiningOp()->moveBefore(yieldOp);
+ } else {
+ builder.create<Terminator>(loc);
+ }
+ } else {
+ builder.create<Terminator>(loc);
+ }
builder.setInsertionPointToStart(&block);
return op;
}
@@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ bool needEarlyReturnHandling = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Value workerNum;
@@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
addOperands(operands, operandSegments, privateOperands);
addOperands(operands, operandSegments, reductionOperands);
+ llvm::SmallVector<mlir::Type> retTy;
+ mlir::Value yieldValue;
+ if (needEarlyReturnHandling) {
+ mlir::Type i1Ty = builder.getI1Type();
+ yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
+ retTy.push_back(i1Ty);
+ }
+
auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
- builder, currentLocation, eval, operands, operandSegments);
+ builder, currentLocation, eval, operands, operandSegments,
+ /*outerCombined=*/false, retTy, yieldValue);
if (hasGang)
loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1647,16 +1669,34 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
return loopOp;
}
-static void genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
+// Return true if any evaluation nested (at any depth) in `eval` is a RETURN.
+static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
+  for (auto &e : eval.getNestedEvaluations()) {
+    bool isReturn = false;
+    e.visit(Fortran::common::visitors{
+        [&](const Fortran::parser::ReturnStmt &) { isReturn = true; },
+        [&](const auto &) {}});
+    if (isReturn || (e.hasNestedEvaluations() && hasEarlyReturn(e)))
+      return true; // first hit wins; later evaluations must not reset it
+  }
+  return false;
+}
+
+static mlir::Value
+genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
const auto &loopDirective =
std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
+ bool needEarlyExitHandling = false;
+ if (eval.lowerAsUnstructured())
+ needEarlyExitHandling = hasEarlyReturn(eval);
+
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
Fortran::lower::StatementContext stmtCtx;
@@ -1664,9 +1704,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
if (loopDirective.v == llvm::acc::ACCD_loop) {
const auto &accClauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
- createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
- accClauseList);
+ auto loopOp =
+ createLoopOp(converter, currentLocation, eval, semanticsContext,
+ stmtCtx, accClauseList, needEarlyExitHandling);
+ if (needEarlyExitHandling)
+ return loopOp.getResult(0);
}
+ return mlir::Value{};
}
template <typename Op, typename Clause>
@@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
builder.restoreInsertionPoint(crtPos);
}
-void Fortran::lower::genOpenACCConstruct(
+mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCConstruct &accConstruct) {
+ mlir::Value exitCond;
std::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
@@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct(
genACC(converter, semanticsContext, eval, combinedConstruct);
},
[&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
- genACC(converter, semanticsContext, eval, loopConstruct);
+ exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
},
[&](const Fortran::parser::OpenACCStandaloneConstruct
&standaloneConstruct) {
@@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct(
},
},
accConstruct.u);
+ return exitCond;
}
void Fortran::lower::genOpenACCDeclarativeConstruct(
@@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
else
builder.create<mlir::acc::TerminatorOp>(loc);
}
+
+// Return true if the builder's current block is enclosed by an acc.loop op.
+bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
+  mlir::Region *region = builder.getBlock()->getParent();
+  return static_cast<bool>(region->getParentOfType<mlir::acc::LoopOp>());
+}
+
+void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
+    fir::FirOpBuilder &builder) { // no-op when not inside an acc.loop
+  mlir::Region *region = builder.getBlock()->getParent();
+  if (auto enclosingLoop = region->getParentOfType<mlir::acc::LoopOp>())
+    builder.setInsertionPointAfter(enclosingLoop); // insert right after it
+}
+
+void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
+                                                 mlir::Location loc) {
+  // Yielding `true` tells the caller the loop was left through a RETURN.
+  mlir::Value earlyExit = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+  builder.create<mlir::acc::YieldOp>(loc, earlyExit);
+}
diff --git a/flang/test/Lower/OpenACC/acc-loop-exit.f90 b/flang/test/Lower/OpenACC/acc-loop-exit.f90
new file mode 100644
index 000000000000000..75f1c3073327228
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-exit.f90
@@ -0,0 +1,37 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+! Check that a RETURN inside an acc.loop yields true and branches to the exit block.
+subroutine sub1(x, a)
+ real :: x(200)
+ integer :: a
+
+ !$acc loop
+ do i = 100, 200
+ x(i) = 1.0
+ if (i == a) return
+ end do
+
+ i = 2
+end
+
+! CHECK-LABEL: func.func @_QPsub1
+! CHECK: %[[A:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFsub1Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[EXIT_COND:.*]] = acc.loop {
+! CHECK: ^bb{{.*}}:
+! CHECK: ^bb{{.*}}:
+! CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+! CHECK: %[[CMP:.*]] = arith.cmpi eq, %{{.*}}, %[[LOAD_A]] : i32
+! CHECK: cf.cond_br %[[CMP]], ^[[EARLY_RET:.*]], ^[[NO_RET:.*]]
+! CHECK: ^[[EARLY_RET]]:
+! CHECK: acc.yield %true : i1
+! CHECK: ^[[NO_RET]]:
+! CHECK: cf.br ^bb{{.*}}
+! CHECK: ^bb{{.*}}:
+! CHECK: acc.yield %false : i1
+! CHECK: }(i1)
+! CHECK: cf.cond_br %[[EXIT_COND]], ^[[EXIT_BLOCK:.*]], ^[[CONTINUE_BLOCK:.*]]
+! CHECK: ^[[CONTINUE_BLOCK]]:
+! CHECK: hlfir.assign
+! CHECK: cf.br ^[[EXIT_BLOCK]]
+! CHECK: ^[[EXIT_BLOCK]]:
+! CHECK: return
+! CHECK: }
``````````
</details>
https://github.com/llvm/llvm-project/pull/73841
More information about the flang-commits
mailing list