[flang-commits] [flang] [flang][openacc] Support early return in acc.loop (PR #73841)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Nov 29 10:30:54 PST 2023


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/73841

An early return is accepted in an OpenACC loop that is not directly nested in a compute construct. Since the `acc.loop` operation has a region, the `func.return` operation cannot be used directly inside that region.
An early return is materialized by an `acc.yield` operation returning a `true` value. The standard end of the `acc.loop` region yields a `false` value in this case.
A conditional branch operation on the `acc.loop` result then branches to the `finalBlock` if an early exit was taken inside the `acc.loop`, or to the continue block otherwise.


>From 4a19fe70b1bb945d0647047ba9356aa338e15eef Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 28 Nov 2023 21:39:43 -0800
Subject: [PATCH] [flang][openacc] Support early return in acc.loop

---
 flang/include/flang/Lower/OpenACC.h        | 13 ++-
 flang/lib/Lower/Bridge.cpp                 | 23 +++++-
 flang/lib/Lower/OpenACC.cpp                | 94 ++++++++++++++++++----
 flang/test/Lower/OpenACC/acc-loop-exit.f90 | 37 +++++++++
 4 files changed, 148 insertions(+), 19 deletions(-)
 create mode 100644 flang/test/Lower/OpenACC/acc-loop-exit.f90

diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 409956f0ecb309f..f23e4726f33e00a 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -64,9 +64,10 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
 static constexpr llvm::StringRef declarePostDeallocSuffix =
     "_acc_declare_update_desc_post_dealloc";
 
-void genOpenACCConstruct(AbstractConverter &,
-                         Fortran::semantics::SemanticsContext &,
-                         pft::Evaluation &, const parser::OpenACCConstruct &);
+mlir::Value genOpenACCConstruct(AbstractConverter &,
+                                Fortran::semantics::SemanticsContext &,
+                                pft::Evaluation &,
+                                const parser::OpenACCConstruct &);
 void genOpenACCDeclarativeConstruct(AbstractConverter &,
                                     Fortran::semantics::SemanticsContext &,
                                     StatementContext &,
@@ -112,6 +113,12 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
 void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
                           mlir::Location);
 
+bool isInOpenACCLoop(fir::FirOpBuilder &);
+
+void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);
+
+void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location);
+
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 23c48cc7bd97874..45da1355df168e2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2382,11 +2382,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
-    genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
+    mlir::Value exitCond = genOpenACCConstruct(
+        *this, bridge.getSemanticsContext(), getEval(), acc);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
+
+    const Fortran::parser::OpenACCLoopConstruct *accLoop =
+        std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
+    if (accLoop && exitCond) {
+      Fortran::lower::pft::FunctionLikeUnit *funit =
+          getEval().getOwningProcedure();
+      assert(funit && "not inside main program, function or subroutine");
+      mlir::Block *continueBlock =
+          builder->getBlock()->splitBlock(builder->getBlock()->end());
+      builder->create<mlir::cf::CondBranchOp>(toLocation(), exitCond,
+                                              funit->finalBlock, continueBlock);
+      builder->setInsertionPointToEnd(continueBlock);
+    }
   }
 
   void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
@@ -4091,10 +4105,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // Branch to the last block of the SUBROUTINE, which has the actual return.
     if (!funit->finalBlock) {
       mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+      Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(*builder);
       funit->finalBlock = builder->createBlock(&builder->getRegion());
       builder->restoreInsertionPoint(insPt);
     }
-    builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
+
+    if (Fortran::lower::isInOpenACCLoop(*builder))
+      Fortran::lower::genEarlyReturnInOpenACCLoop(*builder, loc);
+    else
+      builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
   }
 
   void genFIR(const Fortran::parser::CycleStmt &) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 8c6c22210cf0894..e2abed1b9f4f675 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -25,10 +25,12 @@
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/IntrinsicCall.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "llvm/Frontend/OpenACC/ACC.h.inc"
 
 // Special value for * passed in device_type or gang clauses.
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
                          Fortran::lower::pft::Evaluation &eval,
                          const llvm::SmallVectorImpl<mlir::Value> &operands,
                          const llvm::SmallVectorImpl<int32_t> &operandSegments,
-                         bool outerCombined = false) {
-  llvm::ArrayRef<mlir::Type> argTy;
-  Op op = builder.create<Op>(loc, argTy, operands);
+                         bool outerCombined = false,
+                         llvm::SmallVector<mlir::Type> retTy = {},
+                         mlir::Value yieldValue = {}) {
+  Op op = builder.create<Op>(loc, retTy, operands);
   builder.createBlock(&op.getRegion());
   mlir::Block &block = op.getRegion().back();
   builder.setInsertionPointToStart(&block);
@@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
                                             mlir::acc::YieldOp>(
         builder, eval.getNestedEvaluations());
 
-  builder.create<Terminator>(loc);
+  if (yieldValue) {
+    if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
+      Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
+      yieldValue.getDefiningOp()->moveBefore(yieldOp);
+    } else {
+      builder.create<Terminator>(loc);
+    }
+  } else {
+    builder.create<Terminator>(loc);
+  }
   builder.setInsertionPointToStart(&block);
   return op;
 }
@@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
              Fortran::lower::pft::Evaluation &eval,
              Fortran::semantics::SemanticsContext &semanticsContext,
              Fortran::lower::StatementContext &stmtCtx,
-             const Fortran::parser::AccClauseList &accClauseList) {
+             const Fortran::parser::AccClauseList &accClauseList,
+             bool needEarlyReturnHandling = false) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
   mlir::Value workerNum;
@@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   addOperands(operands, operandSegments, privateOperands);
   addOperands(operands, operandSegments, reductionOperands);
 
+  llvm::SmallVector<mlir::Type> retTy;
+  mlir::Value yieldValue;
+  if (needEarlyReturnHandling) {
+    mlir::Type i1Ty = builder.getI1Type();
+    yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
+    retTy.push_back(i1Ty);
+  }
+
   auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
-      builder, currentLocation, eval, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments,
+      /*outerCombined=*/false, retTy, yieldValue);
 
   if (hasGang)
     loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1647,16 +1669,34 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   return loopOp;
 }
 
-static void genACC(Fortran::lower::AbstractConverter &converter,
-                   Fortran::semantics::SemanticsContext &semanticsContext,
-                   Fortran::lower::pft::Evaluation &eval,
-                   const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
+static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
+  bool hasReturnStmt = false;
+  for (auto &e : eval.getNestedEvaluations()) {
+    e.visit(Fortran::common::visitors{
+        [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
+        [&](const auto &s) {},
+    });
+    if (e.hasNestedEvaluations())
+      hasReturnStmt = hasEarlyReturn(e);
+  }
+  return hasReturnStmt;
+}
+
+static mlir::Value
+genACC(Fortran::lower::AbstractConverter &converter,
+       Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
+       const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
 
   const auto &beginLoopDirective =
       std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
   const auto &loopDirective =
       std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
 
+  bool needEarlyExitHandling = false;
+  if (eval.lowerAsUnstructured())
+    needEarlyExitHandling = hasEarlyReturn(eval);
+
   mlir::Location currentLocation =
       converter.genLocation(beginLoopDirective.source);
   Fortran::lower::StatementContext stmtCtx;
@@ -1664,9 +1704,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   if (loopDirective.v == llvm::acc::ACCD_loop) {
     const auto &accClauseList =
         std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    auto loopOp =
+        createLoopOp(converter, currentLocation, eval, semanticsContext,
+                     stmtCtx, accClauseList, needEarlyExitHandling);
+    if (needEarlyExitHandling)
+      return loopOp.getResult(0);
   }
+  return mlir::Value{};
 }
 
 template <typename Op, typename Clause>
@@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
   builder.restoreInsertionPoint(crtPos);
 }
 
-void Fortran::lower::genOpenACCConstruct(
+mlir::Value Fortran::lower::genOpenACCConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OpenACCConstruct &accConstruct) {
 
+  mlir::Value exitCond;
   std::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
@@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct(
             genACC(converter, semanticsContext, eval, combinedConstruct);
           },
           [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
-            genACC(converter, semanticsContext, eval, loopConstruct);
+            exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
           },
           [&](const Fortran::parser::OpenACCStandaloneConstruct
                   &standaloneConstruct) {
@@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct(
           },
       },
       accConstruct.u);
+  return exitCond;
 }
 
 void Fortran::lower::genOpenACCDeclarativeConstruct(
@@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
   else
     builder.create<mlir::acc::TerminatorOp>(loc);
 }
+
+bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
+  if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+    return true;
+  return false;
+}
+
+void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
+    fir::FirOpBuilder &builder) {
+  if (auto loopOp =
+          builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+    builder.setInsertionPointAfter(loopOp);
+}
+
+void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
+                                                 mlir::Location loc) {
+  mlir::Value yieldValue =
+      builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+  builder.create<mlir::acc::YieldOp>(loc, yieldValue);
+}
diff --git a/flang/test/Lower/OpenACC/acc-loop-exit.f90 b/flang/test/Lower/OpenACC/acc-loop-exit.f90
new file mode 100644
index 000000000000000..75f1c3073327228
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-exit.f90
@@ -0,0 +1,37 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1(x, a)
+  real :: x(200)
+  integer :: a
+
+  !$acc loop
+  do i = 100, 200
+    x(i) = 1.0
+    if (i == a) return
+  end do
+
+  i = 2
+end 
+
+! CHECK-LABEL: func.func @_QPsub1
+! CHECK: %[[A:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFsub1Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[EXIT_COND:.*]] = acc.loop {
+! CHECK: ^bb{{.*}}:
+! CHECK: ^bb{{.*}}:
+! CHECK:   %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+! CHECK:   %[[CMP:.*]] = arith.cmpi eq, %15, %[[LOAD_A]] : i32
+! CHECK:   cf.cond_br %[[CMP]], ^[[EARLY_RET:.*]], ^[[NO_RET:.*]]
+! CHECK: ^[[EARLY_RET]]:
+! CHECK:   acc.yield %true : i1
+! CHECK: ^[[NO_RET]]:
+! CHECK:   cf.br ^bb{{.*}}
+! CHECK: ^bb{{.*}}:
+! CHECK:   acc.yield %false : i1
+! CHECK: }(i1)
+! CHECK: cf.cond_br %[[EXIT_COND]], ^[[EXIT_BLOCK:.*]], ^[[CONTINUE_BLOCK:.*]]
+! CHECK: ^[[CONTINUE_BLOCK]]:
+! CHECK:   hlfir.assign
+! CHECK:   cf.br ^[[EXIT_BLOCK]]
+! CHECK: ^[[EXIT_BLOCK]]:
+! CHECK:   return
+! CHECK: }



More information about the flang-commits mailing list