[flang-commits] [flang] b1341e2 - [flang][openacc] Fix unstructured code in OpenACC region ops (#66284)

via flang-commits flang-commits at lists.llvm.org
Wed Sep 13 20:49:59 PDT 2023


Author: Valentin Clement (バレンタイン クレメン)
Date: 2023-09-13T20:49:54-07:00
New Revision: b1341e2863e7aae035827dd3f8a2373fa2617c8a

URL: https://github.com/llvm/llvm-project/commit/b1341e2863e7aae035827dd3f8a2373fa2617c8a
DIFF: https://github.com/llvm/llvm-project/commit/b1341e2863e7aae035827dd3f8a2373fa2617c8a.diff

LOG: [flang][openacc] Fix unstructured code in OpenACC region ops (#66284)

For unstructured constructs, the blocks are created in advance inside the
function body. This causes issues when the unstructured construct is
inside an OpenACC region operation. This patch applies the same fix as
the OpenMP lowering and re-creates the blocks inside the op region.

Initial OpenMP fix: 29f167abcf7d871d17dd3f38f361916de1a12470

Added: 
    flang/test/Lower/OpenACC/acc-unstructured.f90

Modified: 
    flang/lib/Lower/DirectivesCommon.h
    flang/lib/Lower/OpenACC.cpp
    flang/lib/Lower/OpenMP.cpp
    flang/test/Lower/OpenACC/stop-stmt-in-region.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 35825a20b4cf93f..efac311bec83338 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -587,7 +587,31 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   firOpBuilder.setInsertionPointToStart(&block);
 }
 
+/// Create empty blocks for the current region.
+/// These blocks replace blocks parented to an enclosing region.
+template <typename... TerminatorOps>
+void createEmptyRegionBlocks(
+    fir::FirOpBuilder &builder,
+    std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
+  mlir::Region *region = &builder.getRegion();
+  for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
+    if (eval.block) {
+      if (eval.block->empty()) {
+        eval.block->erase();
+        eval.block = builder.createBlock(region);
+      } else {
+        [[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
+        assert(mlir::isa<TerminatorOps...>(terminatorOp) &&
+               "expected terminator op");
+      }
+    }
+    if (!eval.isDirective() && eval.hasNestedEvaluations())
+      createEmptyRegionBlocks<TerminatorOps...>(builder,
+                                                eval.getNestedEvaluations());
+  }
+}
+
 } // namespace lower
 } // namespace Fortran
 
-#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H
\ No newline at end of file
+#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H

diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 732765c4def59cb..180c077d39c9ee2 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1250,16 +1250,16 @@ static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
 }
 
 template <typename Op, typename Terminator>
-static Op
-createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
-               const llvm::SmallVectorImpl<mlir::Value> &operands,
-               const llvm::SmallVectorImpl<int32_t> &operandSegments) {
+static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
+                         Fortran::lower::pft::Evaluation &eval,
+                         const llvm::SmallVectorImpl<mlir::Value> &operands,
+                         const llvm::SmallVectorImpl<int32_t> &operandSegments,
+                         bool outerCombined = false) {
   llvm::ArrayRef<mlir::Type> argTy;
   Op op = builder.create<Op>(loc, argTy, operands);
   builder.createBlock(&op.getRegion());
   mlir::Block &block = op.getRegion().back();
   builder.setInsertionPointToStart(&block);
-  builder.create<Terminator>(loc);
 
   op->setAttr(Op::getOperandSegmentSizeAttr(),
               builder.getDenseI32ArrayAttr(operandSegments));
@@ -1267,6 +1267,15 @@ createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
   // Place the insertion point to the start of the first block.
   builder.setInsertionPointToStart(&block);
 
+  // If it is an unstructured region and is not the outer region of a combined
+  // construct, create empty blocks for all evaluations.
+  if (eval.lowerAsUnstructured() && !outerCombined)
+    Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp,
+                                            mlir::acc::YieldOp>(
+        builder, eval.getNestedEvaluations());
+
+  builder.create<Terminator>(loc);
+  builder.setInsertionPointToStart(&block);
   return op;
 }
 
@@ -1347,6 +1356,7 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
 static mlir::acc::LoopOp
 createLoopOp(Fortran::lower::AbstractConverter &converter,
              mlir::Location currentLocation,
+             Fortran::lower::pft::Evaluation &eval,
              Fortran::semantics::SemanticsContext &semanticsContext,
              Fortran::lower::StatementContext &stmtCtx,
              const Fortran::parser::AccClauseList &accClauseList) {
@@ -1455,7 +1465,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   addOperands(operands, operandSegments, cacheOperands);
 
   auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
-      builder, currentLocation, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments);
 
   if (hasGang)
     loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1504,6 +1514,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
 
 static void genACC(Fortran::lower::AbstractConverter &converter,
                    Fortran::semantics::SemanticsContext &semanticsContext,
+                   Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
 
   const auto &beginLoopDirective =
@@ -1518,7 +1529,7 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   if (loopDirective.v == llvm::acc::ACCD_loop) {
     const auto &accClauseList =
         std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   }
 }
@@ -1551,9 +1562,11 @@ template <typename Op>
 static Op
 createComputeOp(Fortran::lower::AbstractConverter &converter,
                 mlir::Location currentLocation,
+                Fortran::lower::pft::Evaluation &eval,
                 Fortran::semantics::SemanticsContext &semanticsContext,
                 Fortran::lower::StatementContext &stmtCtx,
-                const Fortran::parser::AccClauseList &accClauseList) {
+                const Fortran::parser::AccClauseList &accClauseList,
+                bool outerCombined = false) {
 
   // Parallel operation operands
   mlir::Value async;
@@ -1769,10 +1782,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   Op computeOp;
   if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
     computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
-        builder, currentLocation, operands, operandSegments);
+        builder, currentLocation, eval, operands, operandSegments,
+        outerCombined);
   else
     computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
-        builder, currentLocation, operands, operandSegments);
+        builder, currentLocation, eval, operands, operandSegments,
+        outerCombined);
 
   if (addAsyncAttr)
     computeOp.setAsyncAttrAttr(builder.getUnitAttr());
@@ -1817,6 +1832,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
 static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                          mlir::Location currentLocation,
+                         Fortran::lower::pft::Evaluation &eval,
                          Fortran::semantics::SemanticsContext &semanticsContext,
                          Fortran::lower::StatementContext &stmtCtx,
                          const Fortran::parser::AccClauseList &accClauseList) {
@@ -1942,7 +1958,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
     return;
 
   auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
-      builder, currentLocation, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments);
 
   dataOp.setAsyncAttr(addAsyncAttr);
   dataOp.setWaitAttr(addWaitAttr);
@@ -1971,6 +1987,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
 static void
 genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
                  mlir::Location currentLocation,
+                 Fortran::lower::pft::Evaluation &eval,
                  Fortran::semantics::SemanticsContext &semanticsContext,
                  Fortran::lower::StatementContext &stmtCtx,
                  const Fortran::parser::AccClauseList &accClauseList) {
@@ -2020,7 +2037,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
 
   auto hostDataOp =
       createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
-          builder, currentLocation, operands, operandSegments);
+          builder, currentLocation, eval, operands, operandSegments);
 
   if (addIfPresentAttr)
     hostDataOp.setIfPresentAttr(builder.getUnitAttr());
@@ -2029,6 +2046,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
 static void
 genACC(Fortran::lower::AbstractConverter &converter,
        Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
   const auto &beginBlockDirective =
       std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
@@ -2041,26 +2059,30 @@ genACC(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
 
   if (blockDirective.v == llvm::acc::ACCD_parallel) {
-    createComputeOp<mlir::acc::ParallelOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+    createComputeOp<mlir::acc::ParallelOp>(converter, currentLocation, eval,
+                                           semanticsContext, stmtCtx,
+                                           accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_data) {
-    genACCDataOp(converter, currentLocation, semanticsContext, stmtCtx,
+    genACCDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_serial) {
-    createComputeOp<mlir::acc::SerialOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+    createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
+                                         semanticsContext, stmtCtx,
+                                         accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_kernels) {
-    createComputeOp<mlir::acc::KernelsOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+    createComputeOp<mlir::acc::KernelsOp>(converter, currentLocation, eval,
+                                          semanticsContext, stmtCtx,
+                                          accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
-    genACCHostDataOp(converter, currentLocation, semanticsContext, stmtCtx,
-                     accClauseList);
+    genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
+                     stmtCtx, accClauseList);
   }
 }
 
 static void
 genACC(Fortran::lower::AbstractConverter &converter,
        Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
   const auto &beginCombinedDirective =
       std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
@@ -2075,18 +2097,21 @@ genACC(Fortran::lower::AbstractConverter &converter,
 
   if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
     createComputeOp<mlir::acc::KernelsOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+        converter, currentLocation, eval, semanticsContext, stmtCtx,
+        accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
     createComputeOp<mlir::acc::ParallelOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+        converter, currentLocation, eval, semanticsContext, stmtCtx,
+        accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
-    createComputeOp<mlir::acc::SerialOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+    createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
+                                         semanticsContext, stmtCtx,
+                                         accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else {
     llvm::report_fatal_error("Unknown combined construct encountered");
@@ -3169,14 +3194,14 @@ void Fortran::lower::genOpenACCConstruct(
   std::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
-            genACC(converter, semanticsContext, blockConstruct);
+            genACC(converter, semanticsContext, eval, blockConstruct);
           },
           [&](const Fortran::parser::OpenACCCombinedConstruct
                   &combinedConstruct) {
-            genACC(converter, semanticsContext, combinedConstruct);
+            genACC(converter, semanticsContext, eval, combinedConstruct);
           },
           [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
-            genACC(converter, semanticsContext, loopConstruct);
+            genACC(converter, semanticsContext, eval, loopConstruct);
           },
           [&](const Fortran::parser::OpenACCStandaloneConstruct
                   &standaloneConstruct) {

diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index b960bb369dd4dd2..be9c0dcdbdbf485 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1981,29 +1981,6 @@ static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
   return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
 }
 
-/// Create empty blocks for the current region.
-/// These blocks replace blocks parented to an enclosing region.
-static void createEmptyRegionBlocks(
-    fir::FirOpBuilder &firOpBuilder,
-    std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
-  mlir::Region *region = &firOpBuilder.getRegion();
-  for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
-    if (eval.block) {
-      if (eval.block->empty()) {
-        eval.block->erase();
-        eval.block = firOpBuilder.createBlock(region);
-      } else {
-        [[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
-        assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) ||
-                mlir::isa<mlir::omp::YieldOp>(terminatorOp)) &&
-               "expected terminator op");
-      }
-    }
-    if (!eval.isDirective() && eval.hasNestedEvaluations())
-      createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
-  }
-}
-
 static void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder,
                                   mlir::Operation *storeOp,
                                   mlir::Block &block) {
@@ -2092,7 +2069,9 @@ static void createBodyOfOp(
   // If it is an unstructured region and is not the outer region of a combined
   // construct, create empty blocks for all evaluations.
   if (eval.lowerAsUnstructured() && !outerCombined)
-    createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
+    Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
+                                            mlir::omp::YieldOp>(
+        firOpBuilder, eval.getNestedEvaluations());
 
   // Insert the terminator.
   if constexpr (std::is_same_v<Op, mlir::omp::WsLoopOp> ||

diff  --git a/flang/test/Lower/OpenACC/acc-unstructured.f90 b/flang/test/Lower/OpenACC/acc-unstructured.f90
new file mode 100644
index 000000000000000..bd9f3284d9fc25c
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-unstructured.f90
@@ -0,0 +1,86 @@
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine test_unstructured1(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc data copy(a, b, c)
+
+  !$acc kernels
+  a(:,:,:) = 0.0
+  !$acc end kernels
+
+  !$acc kernels
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+      end do
+    end do
+  end do
+  !$acc end kernels
+
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+      end do
+    end do
+
+    if (a(1,2,3) > 10) stop 'just to be unstructured'
+  end do
+
+  !$acc end data
+
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_unstructured1
+! CHECK: acc.data
+! CHECK: acc.kernels
+! CHECK: acc.kernels
+! CHECK: fir.call @_FortranAStopStatementText
+
+
+subroutine test_unstructured2(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc parallel loop
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+        if (a(1,2,3) > 10) stop 'just to be unstructured'
+      end do
+    end do
+  end do
+
+! CHECK-LABEL: func.func @_QPtest_unstructured2
+! CHECK: acc.parallel
+! CHECK: acc.loop
+! CHECK: fir.call @_FortranAStopStatementText
+! CHECK: fir.unreachable
+! CHECK: acc.yield
+! CHECK: acc.yield
+
+end subroutine
+
+subroutine test_unstructured3(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc parallel
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+        if (a(1,2,3) > 10) stop 'just to be unstructured'
+      end do
+    end do
+  end do
+  !$acc end parallel
+
+! CHECK-LABEL: func.func @_QPtest_unstructured3
+! CHECK: acc.parallel
+! CHECK: fir.call @_FortranAStopStatementText
+! CHECK: fir.unreachable
+! CHECK: acc.yield
+
+end subroutine

diff  --git a/flang/test/Lower/OpenACC/stop-stmt-in-region.f90 b/flang/test/Lower/OpenACC/stop-stmt-in-region.f90
index 4b3e5632650f1cf..bec9d53b54c0f1d 100644
--- a/flang/test/Lower/OpenACC/stop-stmt-in-region.f90
+++ b/flang/test/Lower/OpenACC/stop-stmt-in-region.f90
@@ -29,7 +29,7 @@ subroutine test_stop_in_region1()
 ! CHECK:           %[[VAL_2:.*]] = arith.constant false
 ! CHECK:           %[[VAL_3:.*]] = arith.constant false
 ! CHECK:           %[[VAL_4:.*]] = fir.call @_FortranAStopStatement(%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) {{.*}} : (i32, i1, i1) -> none
-! CHECK:           acc.yield
+! CHECK:           fir.unreachable
 ! CHECK:         }
 ! CHECK:         return
 ! CHECK:       }


        


More information about the flang-commits mailing list