[flang-commits] [flang] [flang][openacc] Fix unstructured code in OpenACC region ops (PR #66284)

via flang-commits flang-commits at lists.llvm.org
Wed Sep 13 13:18:18 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp
            
<details>
<summary>Changes</summary>
For unstructured constructs, the blocks are created in advance inside the function body. This causes issues when the unstructured construct is inside an OpenACC region operation. This patch applies the same fix as the OpenMP lowering and re-creates the blocks inside the op region.

Initial OpenMP fix: 29f167abcf7d871d17dd3f38f361916de1a12470
--
Full diff: https://github.com/llvm/llvm-project/pull/66284.diff

4 Files Affected:

- (modified) flang/lib/Lower/DirectivesCommon.h (+23-1) 
- (modified) flang/lib/Lower/OpenACC.cpp (+48-23) 
- (modified) flang/lib/Lower/OpenMP.cpp (+1-24) 
- (added) flang/test/Lower/OpenACC/acc-unstructured.f90 (+86) 


<pre>
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 35825a20b4cf93f..59d46008bcca9e3 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -587,7 +587,29 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   firOpBuilder.setInsertionPointToStart(&block);
 }
 
+/// Create empty blocks for the current region.
+/// These blocks replace blocks parented to an enclosing region.
+template <typename... TerminatorOps>
+void createEmptyRegionBlocks(fir::FirOpBuilder &builder,
+    std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
+  mlir::Region *region = &builder.getRegion();
+  for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
+    if (eval.block) {
+      if (eval.block->empty()) {
+        eval.block->erase();
+        eval.block = builder.createBlock(region);
+      } else {
+        [[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
+        assert(mlir::isa<TerminatorOps...>(terminatorOp) &&
+            "expected terminator op");
+      }
+    }
+    if (!eval.isDirective() && eval.hasNestedEvaluations())
+      createEmptyRegionBlocks<TerminatorOps...>(builder, eval.getNestedEvaluations());
+  }
+}
+
 } // namespace lower
 } // namespace Fortran
 
-#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H
\ No newline at end of file
+#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 732765c4def59cb..e798876130c1d28 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1252,14 +1252,15 @@ static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
 template <typename Op, typename Terminator>
 static Op
 createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
+               Fortran::lower::pft::Evaluation &eval,
                const llvm::SmallVectorImpl<mlir::Value> &operands,
-               const llvm::SmallVectorImpl<int32_t> &operandSegments) {
+               const llvm::SmallVectorImpl<int32_t> &operandSegments,
+               bool outerCombined = false) {
   llvm::ArrayRef<mlir::Type> argTy;
   Op op = builder.create<Op>(loc, argTy, operands);
   builder.createBlock(&op.getRegion());
   mlir::Block &block = op.getRegion().back();
   builder.setInsertionPointToStart(&block);
-  builder.create<Terminator>(loc);
 
   op->setAttr(Op::getOperandSegmentSizeAttr(),
               builder.getDenseI32ArrayAttr(operandSegments));
@@ -1267,6 +1268,13 @@ createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
   // Place the insertion point to the start of the first block.
   builder.setInsertionPointToStart(&block);
 
+  // If it is an unstructured region and is not the outer region of a combined
+  // construct, create empty blocks for all evaluations.
+  if (eval.lowerAsUnstructured() && !outerCombined)
+    Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp, mlir::acc::YieldOp>(builder, eval.getNestedEvaluations());
+
+  builder.create<Terminator>(loc);
+  builder.setInsertionPointToStart(&block);
   return op;
 }
 
@@ -1347,6 +1355,7 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
 static mlir::acc::LoopOp
 createLoopOp(Fortran::lower::AbstractConverter &converter,
              mlir::Location currentLocation,
+             Fortran::lower::pft::Evaluation &eval,
              Fortran::semantics::SemanticsContext &semanticsContext,
              Fortran::lower::StatementContext &stmtCtx,
              const Fortran::parser::AccClauseList &accClauseList) {
@@ -1455,7 +1464,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   addOperands(operands, operandSegments, cacheOperands);
 
   auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
-      builder, currentLocation, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments);
 
   if (hasGang)
     loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1504,6 +1513,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
 
 static void genACC(Fortran::lower::AbstractConverter &converter,
                    Fortran::semantics::SemanticsContext &semanticsContext,
+                   Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
 
   const auto &beginLoopDirective =
@@ -1518,7 +1528,7 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   if (loopDirective.v == llvm::acc::ACCD_loop) {
     const auto &accClauseList =
         std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   }
 }
@@ -1551,9 +1561,11 @@ template <typename Op>
 static Op
 createComputeOp(Fortran::lower::AbstractConverter &converter,
                 mlir::Location currentLocation,
+                Fortran::lower::pft::Evaluation &eval,
                 Fortran::semantics::SemanticsContext &semanticsContext,
                 Fortran::lower::StatementContext &stmtCtx,
-                const Fortran::parser::AccClauseList &accClauseList) {
+                const Fortran::parser::AccClauseList &accClauseList,
+                bool outerCombined = false) {
 
   // Parallel operation operands
   mlir::Value async;
@@ -1769,10 +1781,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   Op computeOp;
   if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
     computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
-        builder, currentLocation, operands, operandSegments);
+        builder, currentLocation, eval, operands, operandSegments, outerCombined);
   else
     computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
-        builder, currentLocation, operands, operandSegments);
+        builder, currentLocation, eval, operands, operandSegments, outerCombined);
 
   if (addAsyncAttr)
     computeOp.setAsyncAttrAttr(builder.getUnitAttr());
@@ -1817,6 +1829,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
 static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                          mlir::Location currentLocation,
+                         Fortran::lower::pft::Evaluation &eval,
                          Fortran::semantics::SemanticsContext &semanticsContext,
                          Fortran::lower::StatementContext &stmtCtx,
                          const Fortran::parser::AccClauseList &accClauseList) {
@@ -1942,7 +1955,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
     return;
 
   auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
-      builder, currentLocation, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments);
 
   dataOp.setAsyncAttr(addAsyncAttr);
   dataOp.setWaitAttr(addWaitAttr);
@@ -1971,6 +1984,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
 static void
 genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
                  mlir::Location currentLocation,
+                 Fortran::lower::pft::Evaluation &eval,
                  Fortran::semantics::SemanticsContext &semanticsContext,
                  Fortran::lower::StatementContext &stmtCtx,
                  const Fortran::parser::AccClauseList &accClauseList) {
@@ -2020,7 +2034,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
 
   auto hostDataOp =
       createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
-          builder, currentLocation, operands, operandSegments);
+          builder, currentLocation, eval, operands, operandSegments);
 
   if (addIfPresentAttr)
     hostDataOp.setIfPresentAttr(builder.getUnitAttr());
@@ -2029,6 +2043,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
 static void
 genACC(Fortran::lower::AbstractConverter &converter,
        Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
   const auto &beginBlockDirective =
       std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
@@ -2042,18 +2057,18 @@ genACC(Fortran::lower::AbstractConverter &converter,
 
   if (blockDirective.v == llvm::acc::ACCD_parallel) {
     createComputeOp<mlir::acc::ParallelOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_data) {
-    genACCDataOp(converter, currentLocation, semanticsContext, stmtCtx,
+    genACCDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_serial) {
     createComputeOp<mlir::acc::SerialOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_kernels) {
     createComputeOp<mlir::acc::KernelsOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
-    genACCHostDataOp(converter, currentLocation, semanticsContext, stmtCtx,
+    genACCHostDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                      accClauseList);
   }
 }
@@ -2061,6 +2076,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
 static void
 genACC(Fortran::lower::AbstractConverter &converter,
        Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
   const auto &beginCombinedDirective =
       std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
@@ -2075,18 +2091,18 @@ genACC(Fortran::lower::AbstractConverter &converter,
 
   if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
     createComputeOp<mlir::acc::KernelsOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
     createComputeOp<mlir::acc::ParallelOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
     createComputeOp<mlir::acc::SerialOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else {
     llvm::report_fatal_error("Unknown combined construct encountered");
@@ -3169,14 +3185,14 @@ void Fortran::lower::genOpenACCConstruct(
   std::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
-            genACC(converter, semanticsContext, blockConstruct);
+            genACC(converter, semanticsContext, eval, blockConstruct);
           },
           [&](const Fortran::parser::OpenACCCombinedConstruct
                   &combinedConstruct) {
-            genACC(converter, semanticsContext, combinedConstruct);
+            genACC(converter, semanticsContext, eval, combinedConstruct);
           },
           [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
-            genACC(converter, semanticsContext, loopConstruct);
+            genACC(converter, semanticsContext, eval, loopConstruct);
           },
           [&](const Fortran::parser::OpenACCStandaloneConstruct
                   &standaloneConstruct) {
@@ -3274,3 +3290,12 @@ void Fortran::lower::attachDeclarePostDeallocAction(
                  /*preAlloc=*/{}, /*postAlloc=*/{}, /*preDealloc=*/{},
                  /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
 }
+
+void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
+                                          mlir::Operation *op,
+                                          mlir::Location loc) {
+  if (mlir::isa<mlir::acc::ParallelOp, mlir::acc::LoopOp>(op))
+    builder.create<mlir::acc::YieldOp>(loc);
+  else
+    builder.create<mlir::acc::TerminatorOp>(loc);
+}
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index b960bb369dd4dd2..9f239f2921c00f9 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1981,29 +1981,6 @@ static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
   return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
 }
 
-/// Create empty blocks for the current region.
-/// These blocks replace blocks parented to an enclosing region.
-static void createEmptyRegionBlocks(
-    fir::FirOpBuilder &firOpBuilder,
-    std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
-  mlir::Region *region = &firOpBuilder.getRegion();
-  for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
-    if (eval.block) {
-      if (eval.block->empty()) {
-        eval.block->erase();
-        eval.block = firOpBuilder.createBlock(region);
-      } else {
-        [[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
-        assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) ||
-                mlir::isa<mlir::omp::YieldOp>(terminatorOp)) &&
-               "expected terminator op");
-      }
-    }
-    if (!eval.isDirective() && eval.hasNestedEvaluations())
-      createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
-  }
-}
-
 static void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder,
                                   mlir::Operation *storeOp,
                                   mlir::Block &block) {
@@ -2092,7 +2069,7 @@ static void createBodyOfOp(
   // If it is an unstructured region and is not the outer region of a combined
   // construct, create empty blocks for all evaluations.
   if (eval.lowerAsUnstructured() && !outerCombined)
-    createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
+    Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp, mlir::omp::YieldOp>(firOpBuilder, eval.getNestedEvaluations());
 
   // Insert the terminator.
   if constexpr (std::is_same_v<Op, mlir::omp::WsLoopOp> ||
diff --git a/flang/test/Lower/OpenACC/acc-unstructured.f90 b/flang/test/Lower/OpenACC/acc-unstructured.f90
new file mode 100644
index 000000000000000..4ce474241a4bec7
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-unstructured.f90
@@ -0,0 +1,86 @@
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine test_unstructured1(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc data copy(a, b, c)
+
+  !$acc kernels
+  a(:,:,:) = 0.0
+  !$acc end kernels
+
+  !$acc kernels
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+      end do
+    end do
+  end do
+  !$acc end kernels
+
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+      end do
+    end do
+
+    if (a(1,2,3) > 10) stop 'just to be unstructured'
+  end do
+
+  !$acc end data
+
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_unstructured1
+! CHECK: acc.data
+! CHECK: acc.kernels
+! CHECK: acc.kernels
+! CHECK: fir.call @_FortranAStopStatementText
+
+
+subroutine test_unstructured2(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc parallel loop
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+        if (a(1,2,3) > 10) stop 'just to be unstructured'
+      end do
+    end do
+  end do
+
+! CHECK-LABEL: func.func @_QPtest_unstructured2
+! CHECK: acc.parallel
+! CHECK: acc.loop
+! CHECK: fir.call @_FortranAStopStatementText
+! CHECK: acc.yield
+! CHECK: acc.yield
+! CHECK: acc.yield
+
+end subroutine
+
+subroutine test_unstructured3(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc parallel
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+        if (a(1,2,3) > 10) stop 'just to be unstructured'
+      end do
+    end do
+  end do
+  !$acc end parallel
+
+! CHECK-LABEL: func.func @_QPtest_unstructured3
+! CHECK: acc.parallel
+! CHECK: fir.call @_FortranAStopStatementText
+! CHECK: acc.yield
+! CHECK: acc.yield
+
+end subroutine
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66284


More information about the flang-commits mailing list