[llvm-branch-commits] [flang] [mlir] [flang] Lower omp.workshare to other omp constructs (PR #101446)

Ivan R. Ivanov via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Aug 18 23:22:08 PDT 2024


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/101446

>From d030f24253b929fe85f935c32b2ae3459897c4a8 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 14:42:35 +0900
Subject: [PATCH 01/23] Add verification and descriptions

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 10 +++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 14 +++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 42 +++++++++++++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 34 +++++++++------
 4 files changed, 87 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 0f29e911cb2f20..74497def8fd1ab 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -301,6 +301,10 @@ def WorkshareOp : OpenMP_Op<"workshare", traits = [
     block into separate units of work, and causes the threads of the team to
     share the work such that each unit is executed only once by one thread, in
     the context of its implicit task
+
+    This operation is used for the intermediate representation of the workshare
+    block before the work gets divided between the threads. See the flang
+    LowerWorkshare pass for details.
   }] # clausesDescription;
 
   let builders = [
@@ -313,10 +317,16 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare_loop_wrapper", traits = [
     RecursiveMemoryEffects, SingleBlock
   ], singleRegion = true> {
   let summary = "contains loop nests to be parallelized by workshare";
+  let description = [{
+    This operation wraps a loop nest that is marked for dividing into units of
+    work by an encompassing omp.workshare operation.
+  }];
+
   let builders = [
     OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
   ];
   let assemblyFormat = "$region attr-dict";
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6c1b77077bdbaf..90f9a19ebe32b5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1692,6 +1692,20 @@ void WorkshareOp::build(OpBuilder &builder, OperationState &state,
   WorkshareOp::build(builder, state, clauses.nowait);
 }
 
+//===----------------------------------------------------------------------===//
+// WorkshareLoopWrapperOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkshareLoopWrapperOp::verify() {
+  if (!isWrapper())
+    return emitOpError() << "must be a loop wrapper";
+  if (getNestedWrapper())
+    return emitError() << "nested wrappers not supported";
+  if (!(*this)->getParentOfType<WorkshareOp>())
+    return emitError() << "must be nested in an omp.workshare";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 1d1d93f0977588..ee7c448c467cf5 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2383,3 +2383,45 @@ func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {
     }) : (i32, i32) -> ()
   return
 }
+
+// -----
+func.func @nested_wrapper(%idx : index) {
+  omp.workshare {
+    // expected-error @below {{nested wrappers not supported}}
+    omp.workshare_loop_wrapper {
+      omp.simd {
+        omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+          omp.yield
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+func.func @not_wrapper() {
+  omp.workshare {
+    // expected-error @below {{must be a loop wrapper}}
+    omp.workshare_loop_wrapper {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+func.func @missing_workshare(%idx : index) {
+  // expected-error @below {{must be nested in an omp.workshare}}
+  omp.workshare_loop_wrapper {
+    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+      omp.yield
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 981e3fbb0306bf..24363334b4cb54 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2826,11 +2826,15 @@ func.func @omp_workshare_multiple_blocks() {
 
 // CHECK-LABEL: func @omp_workshare_loop_wrapper
 func.func @omp_workshare_loop_wrapper(%idx : index) {
-  // CHECK-NEXT: omp.workshare_loop_wrapper
-  omp.workshare_loop_wrapper {
-    // CHECK-NEXT: omp.loop_nest
-    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
-      omp.yield
+  // CHECK-NEXT: omp.workshare {
+  omp.workshare {
+    // CHECK-NEXT: omp.workshare_loop_wrapper
+    omp.workshare_loop_wrapper {
+      // CHECK-NEXT: omp.loop_nest
+      omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+        omp.yield
+      }
+      omp.terminator
     }
     omp.terminator
   }
@@ -2839,14 +2843,18 @@ func.func @omp_workshare_loop_wrapper(%idx : index) {
 
 // CHECK-LABEL: func @omp_workshare_loop_wrapper_attrs
 func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
-  // CHECK-NEXT: omp.workshare_loop_wrapper {
-  omp.workshare_loop_wrapper {
-    // CHECK-NEXT: omp.loop_nest
-    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
-      omp.yield
-    }
+  // CHECK-NEXT: omp.workshare {
+  omp.workshare {
+    // CHECK-NEXT: omp.workshare_loop_wrapper {
+    omp.workshare_loop_wrapper {
+      // CHECK-NEXT: omp.loop_nest
+      omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+        omp.yield
+      }
+      omp.terminator
+    // CHECK: } {attr_in_dict}
+    } {attr_in_dict}
     omp.terminator
-  // CHECK: } {attr_in_dict}
-  } {attr_in_dict}
+  }
   return
 }

>From 69e85558ec78c7a74be9168b4227bb262f67225e Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:11:47 +0900
Subject: [PATCH 02/23] [flang][omp] Emit omp.workshare in frontend

---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 30 ++++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 2b1839b5270d4f..f7bc565ea8cbc1 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1270,6 +1270,15 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter,
       loc, llvm::omp::Directive::OMPD_taskwait);
 }
 
+static void genWorkshareClauses(lower::AbstractConverter &converter,
+                                semantics::SemanticsContext &semaCtx,
+                                lower::StatementContext &stmtCtx,
+                                const List<Clause> &clauses, mlir::Location loc,
+                                mlir::omp::WorkshareOperands &clauseOps) {
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processNowait(clauseOps);
+}
+
 static void genTeamsClauses(lower::AbstractConverter &converter,
                             semantics::SemanticsContext &semaCtx,
                             lower::StatementContext &stmtCtx,
@@ -1890,6 +1899,22 @@ genTaskyieldOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   return converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(loc);
 }
 
+static mlir::omp::WorkshareOp
+genWorkshareOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+           semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+           mlir::Location loc, const ConstructQueue &queue,
+           ConstructQueue::iterator item) {
+  lower::StatementContext stmtCtx;
+  mlir::omp::WorkshareOperands clauseOps;
+  genWorkshareClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
+
+  return genOpWithBody<mlir::omp::WorkshareOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workshare)
+          .setClauses(&item->clauses),
+      queue, item, clauseOps);
+}
+
 static mlir::omp::TeamsOp
 genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
@@ -2249,10 +2274,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                   llvm::omp::getOpenMPDirectiveName(dir) + ")");
   // case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
-    // FIXME: Workshare is not a commonly used OpenMP construct, an
-    // implementation for this feature will come later. For the codes
-    // that use this construct, add a single construct for now.
-    genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genWorkshareOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
 
   // Composite constructs

>From d17c55249b5d7065721225e62f10fef5e3da4f51 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:02:37 +0900
Subject: [PATCH 03/23] Fix lower test for workshare

---
 flang/test/Lower/OpenMP/workshare.f90 | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flang/test/Lower/OpenMP/workshare.f90 b/flang/test/Lower/OpenMP/workshare.f90
index 1e11677a15e1f0..8e771952f5b6da 100644
--- a/flang/test/Lower/OpenMP/workshare.f90
+++ b/flang/test/Lower/OpenMP/workshare.f90
@@ -6,7 +6,7 @@ subroutine sb1(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
   !$omp parallel
-!CHECK: omp.single  {
+!CHECK: omp.workshare {
   !$omp workshare
     arr = 0
   !$omp end workshare
@@ -20,7 +20,7 @@ subroutine sb2(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
   !$omp parallel
-!CHECK: omp.single nowait {
+!CHECK: omp.workshare nowait {
   !$omp workshare
     arr = 0
   !$omp end workshare nowait
@@ -33,7 +33,7 @@ subroutine sb2(arr)
 subroutine sb3(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
-!CHECK: omp.single  {
+!CHECK: omp.workshare  {
   !$omp parallel workshare
     arr = 0
   !$omp end parallel workshare

>From 38c581c8defc81105160a69bb46a9e489b56f10e Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:12:34 +0900
Subject: [PATCH 04/23] [flang] Introduce ws loop nest generation for HLFIR
 lowering

---
 .../flang/Optimizer/Builder/HLFIRTools.h      | 12 +++--
 flang/lib/Lower/ConvertCall.cpp               |  2 +-
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp |  4 +-
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    | 52 ++++++++++++++-----
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       |  3 +-
 .../LowerHLFIROrderedAssignments.cpp          | 30 +++++------
 .../Transforms/OptimizedBufferization.cpp     |  6 +--
 7 files changed, 69 insertions(+), 40 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 6b41025eea0780..14e42c6f358e46 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -357,8 +357,8 @@ hlfir::ElementalOp genElementalOp(
 
 /// Structure to describe a loop nest.
 struct LoopNest {
-  fir::DoLoopOp outerLoop;
-  fir::DoLoopOp innerLoop;
+  mlir::Operation *outerOp;
+  mlir::Block *body;
   llvm::SmallVector<mlir::Value> oneBasedIndices;
 };
 
@@ -366,11 +366,13 @@ struct LoopNest {
 /// \p isUnordered specifies whether the loops in the loop nest
 /// are unordered.
 LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
-                     mlir::ValueRange extents, bool isUnordered = false);
+                     mlir::ValueRange extents, bool isUnordered = false,
+                     bool emitWsLoop = false);
 inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
-                            mlir::Value shape, bool isUnordered = false) {
+                            mlir::Value shape, bool isUnordered = false,
+                            bool emitWsLoop = false) {
   return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
-                     isUnordered);
+                     isUnordered, emitWsLoop);
 }
 
 /// Inline the body of an hlfir.elemental at the current insertion point
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index fd873f55dd844e..0689d6e033dd9c 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -2128,7 +2128,7 @@ class ElementalCallBuilder {
           hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
       mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
       auto insPt = builder.saveInsertionPoint();
-      builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+      builder.setInsertionPointToStart(loopNest.body);
       callContext.stmtCtx.pushScope();
       for (auto &preparedActual : loweredActuals)
         if (preparedActual)
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c3c1f363033c27..72a90dd0d6f29d 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
   // know this won't miss any opportuinties for clever elemental inlining
   hlfir::LoopNest nest = hlfir::genLoopNest(
       loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
-  builder.setInsertionPointToStart(nest.innerLoop.getBody());
+  builder.setInsertionPointToStart(nest.body);
   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
@@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
       builder, loc, redId, refTy, lhsEle, rhsEle);
   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
 
-  builder.setInsertionPointAfter(nest.outerLoop);
+  builder.setInsertionPointAfter(nest.outerOp);
   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
 }
 
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 8d0ae2f195178c..cd07cb741eb4bb 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <optional>
 
 // Return explicit extents. If the base is a fir.box, this won't read it to
@@ -855,26 +856,51 @@ mlir::Value hlfir::inlineElementalOp(
 
 hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
-                                   mlir::ValueRange extents, bool isUnordered) {
+                                   mlir::ValueRange extents, bool isUnordered,
+                                   bool emitWsLoop) {
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
-  auto insPt = builder.saveInsertionPoint();
+  mlir::OpBuilder::InsertionGuard guard(builder);
   loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
   // Build loop nest from column to row.
   auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
   mlir::Type indexType = builder.getIndexType();
-  unsigned dim = extents.size() - 1;
-  for (auto extent : llvm::reverse(extents)) {
-    auto ub = builder.createConvert(loc, indexType, extent);
-    loopNest.innerLoop =
-        builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
-    builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
-    // Reverse the indices so they are in column-major order.
-    loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar();
-    if (!loopNest.outerLoop)
-      loopNest.outerLoop = loopNest.innerLoop;
+  if (emitWsLoop) {
+    auto wsloop = builder.create<mlir::omp::WsloopOp>(
+        loc, mlir::ArrayRef<mlir::NamedAttribute>());
+    loopNest.outerOp = wsloop;
+    builder.createBlock(&wsloop.getRegion());
+    mlir::omp::LoopNestOperands lnops;
+    lnops.loopInclusive = builder.getUnitAttr();
+    for (auto extent : llvm::reverse(extents)) {
+      lnops.loopLowerBounds.push_back(one);
+      lnops.loopUpperBounds.push_back(extent);
+      lnops.loopSteps.push_back(one);
+    }
+    auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
+    builder.create<mlir::omp::TerminatorOp>(loc);
+    mlir::Block *block = builder.createBlock(&lnOp.getRegion());
+    for (auto extent : llvm::reverse(extents))
+      block->addArgument(extent.getType(), extent.getLoc());
+    loopNest.body = block;
+    builder.create<mlir::omp::YieldOp>(loc);
+    for (unsigned dim = 0; dim < extents.size(); dim++)
+      loopNest.oneBasedIndices[extents.size() - dim - 1] =
+          lnOp.getRegion().front().getArgument(dim);
+  } else {
+    unsigned dim = extents.size() - 1;
+    for (auto extent : llvm::reverse(extents)) {
+      auto ub = builder.createConvert(loc, indexType, extent);
+      auto doLoop =
+          builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
+      loopNest.body = doLoop.getBody();
+      builder.setInsertionPointToStart(loopNest.body);
+      // Reverse the indices so they are in column-major order.
+      loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
+      if (!loopNest.outerOp)
+        loopNest.outerOp = doLoop;
+    }
   }
-  builder.restoreInsertionPoint(insPt);
   return loopNest;
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index a70a6b388c4b1a..b608677c526310 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace hlfir {
@@ -793,7 +794,7 @@ struct ElementalOpConversion
     hlfir::LoopNest loopNest =
         hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
     auto insPt = builder.saveInsertionPoint();
-    builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+    builder.setInsertionPointToStart(loopNest.body);
     auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                           loopNest.oneBasedIndices);
     hlfir::Entity elementValue(yield.getElementValue());
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index 85dd517cb57914..645abf65d10a32 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
       // if the LHS is not).
       mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
       elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
-      builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
+      builder.setInsertionPointToStart(elementalLoopNest->body);
       lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
                                       elementalLoopNest->oneBasedIndices);
       rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
@@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
     for (auto &cleanupConversion : argConversionCleanups)
       cleanupConversion();
     if (elementalLoopNest)
-      builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
+      builder.setInsertionPointAfter(elementalLoopNest->outerOp);
   } else {
     // TODO: preserve allocatable assignment aspects for forall once
     // they are conveyed in hlfir.region_assign.
@@ -493,7 +493,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
   generateCleanupIfAny(loweredLhs.elementalCleanup);
   if (loweredLhs.vectorSubscriptLoopNest)
     builder.setInsertionPointAfter(
-        loweredLhs.vectorSubscriptLoopNest->outerLoop);
+        loweredLhs.vectorSubscriptLoopNest->outerOp);
   generateCleanupIfAny(oldRhsYield);
   generateCleanupIfAny(loweredLhs.nonElementalCleanup);
 }
@@ -518,8 +518,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
       hlfir::Entity savedMask{maybeSaved->first};
       mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
       whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
-      constructStack.push_back(whereLoopNest->outerLoop.getOperation());
-      builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+      constructStack.push_back(whereLoopNest->outerOp);
+      builder.setInsertionPointToStart(whereLoopNest->body);
       mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
                                             whereLoopNest->oneBasedIndices);
       generateMaskIfOp(cdt);
@@ -527,7 +527,7 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
         // If this is the same run as the one that saved the value, the clean-up
         // was left-over to be done now.
         auto insertionPoint = builder.saveInsertionPoint();
-        builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+        builder.setInsertionPointAfter(whereLoopNest->outerOp);
         generateCleanupIfAny(maybeSaved->second);
         builder.restoreInsertionPoint(insertionPoint);
       }
@@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
     mask.generateNoneElementalPart(builder, mapper);
     mlir::Value shape = mask.generateShape(builder, mapper);
     whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
-    constructStack.push_back(whereLoopNest->outerLoop.getOperation());
-    builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+    constructStack.push_back(whereLoopNest->outerOp);
+    builder.setInsertionPointToStart(whereLoopNest->body);
     mlir::Value cdt = generateMaskedEntity(mask);
     generateMaskIfOp(cdt);
     return;
@@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
       loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
           loc, builder, loweredLhs.vectorSubscriptShape.value());
       builder.setInsertionPointToStart(
-          loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+          loweredLhs.vectorSubscriptLoopNest->body);
     }
     loweredLhs.lhs = temp->second.fetch(loc, builder);
     return loweredLhs;
@@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
         hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                            !elementalAddrLhs.isOrdered());
     builder.setInsertionPointToStart(
-        loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+        loweredLhs.vectorSubscriptLoopNest->body);
     mapper.map(elementalAddrLhs.getIndices(),
                loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
     for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
@@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
   if (!maskedExpr.noneElementalPartWasGenerated) {
     // Generate none elemental part before the where loops (but inside the
     // current forall loops if any).
-    builder.setInsertionPoint(whereLoopNest->outerLoop);
+    builder.setInsertionPoint(whereLoopNest->outerOp);
     maskedExpr.generateNoneElementalPart(builder, mapper);
   }
   // Generate the none elemental part cleanup after the where loops.
-  builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+  builder.setInsertionPointAfter(whereLoopNest->outerOp);
   maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
   // Generate the value of the current element for the masked expression
   // at the current insertion point (inside the where loops, and any fir.if
@@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
   LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
   fir::factory::TemporaryStorage *temp = nullptr;
   if (loweredLhs.vectorSubscriptLoopNest)
-    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
   if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
     // Vector subscripted entity for which the shape must also be saved on top
     // of the element addresses (e.g. the shape may change in each forall
@@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
     // subscripted LHS.
     auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
     auto insertionPoint = builder.saveInsertionPoint();
-    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
     vectorTmp.pushShape(loc, builder, shape);
     builder.restoreInsertionPoint(insertionPoint);
   } else {
@@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
   if (loweredLhs.vectorSubscriptLoopNest) {
     constructStack.pop_back();
     builder.setInsertionPointAfter(
-        loweredLhs.vectorSubscriptLoopNest->outerLoop);
+        loweredLhs.vectorSubscriptLoopNest->outerOp);
   }
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c5b809514c54c6..c4aed6b79df923 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
   // hlfir.elemental region inside the inner loop
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                         loopNest.oneBasedIndices);
   hlfir::Entity elementValue{yield.getElementValue()};
@@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto arrayElement =
       hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
   builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
@@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto rhsArrayElement =
       hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
   rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);

>From 638ffd43b1b469aec7308c1edd9b431dcfb9c2fd Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:08:34 +0900
Subject: [PATCH 05/23] Emit loop nests in a custom wrapper

---
 flang/include/flang/Optimizer/Builder/HLFIRTools.h |  6 +++---
 flang/lib/Optimizer/Builder/HLFIRTools.cpp         | 11 +++++------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 14e42c6f358e46..69874719572186 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -367,12 +367,12 @@ struct LoopNest {
 /// are unordered.
 LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                      mlir::ValueRange extents, bool isUnordered = false,
-                     bool emitWsLoop = false);
+                     bool emitWorkshareLoop = false);
 inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                             mlir::Value shape, bool isUnordered = false,
-                            bool emitWsLoop = false) {
+                            bool emitWorkshareLoop = false) {
   return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
-                     isUnordered, emitWsLoop);
+                     isUnordered, emitWorkshareLoop);
 }
 
 /// Inline the body of an hlfir.elemental at the current insertion point
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index cd07cb741eb4bb..91b1b3d774a012 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -857,7 +857,7 @@ mlir::Value hlfir::inlineElementalOp(
 hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    mlir::ValueRange extents, bool isUnordered,
-                                   bool emitWsLoop) {
+                                   bool emitWorkshareLoop) {
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
   mlir::OpBuilder::InsertionGuard guard(builder);
@@ -865,11 +865,10 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
   // Build loop nest from column to row.
   auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
   mlir::Type indexType = builder.getIndexType();
-  if (emitWsLoop) {
-    auto wsloop = builder.create<mlir::omp::WsloopOp>(
-        loc, mlir::ArrayRef<mlir::NamedAttribute>());
-    loopNest.outerOp = wsloop;
-    builder.createBlock(&wsloop.getRegion());
+  if (emitWorkshareLoop) {
+    auto wslw = builder.create<mlir::omp::WorkshareLoopWrapperOp>(loc);
+    loopNest.outerOp = wslw;
+    builder.createBlock(&wslw.getRegion());
     mlir::omp::LoopNestOperands lnops;
     lnops.loopInclusive = builder.getUnitAttr();
     for (auto extent : llvm::reverse(extents)) {

>From a9f4b777dfe0e0ef3a44df94487c0f76ab20cccc Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 22:05:47 +0900
Subject: [PATCH 06/23] Only emit unordered loops as omp loops

---
 flang/lib/Optimizer/Builder/HLFIRTools.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 91b1b3d774a012..333331378841ed 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -858,6 +858,7 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    mlir::ValueRange extents, bool isUnordered,
                                    bool emitWorkshareLoop) {
+  emitWorkshareLoop = emitWorkshareLoop && isUnordered;
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
   mlir::OpBuilder::InsertionGuard guard(builder);

>From 30e12940ef4b7c11d3ff749fc4af7d3e3075a6ed Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 22:06:55 +0900
Subject: [PATCH 07/23] [flang] Lower omp.workshare to other omp constructs

---
 flang/include/flang/Optimizer/OpenMP/Passes.h |   2 +
 .../include/flang/Optimizer/OpenMP/Passes.td  |   6 +-
 flang/include/flang/Tools/CLOptions.inc       |   1 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 259 ++++++++++++++++++
 .../Transforms/OpenMP/lower-workshare.mlir    |  81 ++++++
 .../Transforms/OpenMP/lower-workshare5.mlir   |  42 +++
 7 files changed, 391 insertions(+), 1 deletion(-)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare5.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
index 403d79667bf448..11fa4e59f891ea 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.h
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -25,6 +25,8 @@ namespace flangomp {
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 
+bool shouldUseWorkshareLowering(mlir::Operation *op);
+
 } // namespace flangomp
 
 #endif // FORTRAN_OPTIMIZER_OPENMP_PASSES_H
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index a192b9a33c1976..1c9d75d8cfaa18 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -1,4 +1,4 @@
-//===-- Passes.td - HLFIR pass definition file -------------*- tablegen -*-===//
+//===-- Passes.td - flang OpenMP pass definition -----------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -37,4 +37,8 @@ def FunctionFiltering : Pass<"omp-function-filtering"> {
   ];
 }
 
+def LowerWorkshare : Pass<"lower-workshare"> {
+  let summary = "Lower workshare construct";
+}
+
 #endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 05b2f31711add2..a565effebfa92e 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,6 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
+  pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index a8984d256b8f6a..2a38f157a851ce 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -4,6 +4,7 @@ add_flang_library(FlangOpenMPTransforms
   FunctionFiltering.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkshare.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
new file mode 100644
index 00000000000000..40975552d1fe33
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -0,0 +1,259 @@
+//===- LowerWorkshare.cpp - lower the omp workshare construct -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Lower omp workshare construct.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+bool shouldUseWorkshareLowering(Operation *op) {
+  auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
+  if (!workshare)
+    return false;
+  return workshare->getParentOfType<omp::ParallelOp>();
+}
+} // namespace flangomp
+
+namespace {
+
+struct SingleRegion {
+  Block::iterator begin, end;
+};
+
+static bool isSupportedByFirAlloca(Type ty) {
+  return !isa<fir::ReferenceType>(ty);
+}
+
+static bool isSafeToParallelize(Operation *op) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface) {
+    return false;
+  }
+  interface.getEffects(effects);
+  if (effects.empty())
+    return true;
+
+  return false;
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.wsloop {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.wsloop {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which need
+/// to be accessed in all threads in the closest omp.parallel
+///
+/// TODO currently we need to be able to access the encompassing omp.parallel so
+/// that we can allocate temporaries accessible by all threads outside of it.
+/// In case we do not find it, we fall back to converting the omp.workshare to
+/// omp.single.
+/// To better handle this we should probably enable yielding values out of an
+/// omp.single which will be supported by the omp runtime.
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  assert(wsOp.getRegion().getBlocks().size() == 1);
+
+  Location loc = wsOp->getLoc();
+
+  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
+  if (!parallelOp) {
+    wsOp.emitWarning("cannot handle workshare, converting to single");
+    Operation *terminator = wsOp.getRegion().front().getTerminator();
+    wsOp->getBlock()->getOperations().splice(
+        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
+    terminator->erase();
+    return;
+  }
+
+  OpBuilder allocBuilder(parallelOp);
+  OpBuilder rootBuilder(wsOp);
+  IRMapping rootMapping;
+
+  omp::SingleOp singleOp = nullptr;
+
+  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
+                              IRMapping singleMapping) {
+    if (auto reloaded = rootMapping.lookupOrNull(v))
+      return;
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type ty = v.getType();
+    Value alloc, reloaded;
+    if (isSupportedByFirAlloca(ty)) {
+      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+    } else {
+      auto one = allocBuilder.create<LLVM::ConstantOp>(
+          loc, allocBuilder.getI32Type(), 1);
+      alloc =
+          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+      Value toStore = singleBuilder
+                          .create<UnrealizedConversionCastOp>(
+                              loc, llvmPtrTy, singleMapping.lookup(v))
+                          .getResult(0);
+      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
+      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded =
+          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+              .getResult(0);
+    }
+    rootMapping.map(v, reloaded);
+  };
+
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+    IRMapping singleMapping = rootMapping;
+
+    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+      singleBuilder.clone(op, singleMapping);
+      if (isSafeToParallelize(&op)) {
+        rootBuilder.clone(op, rootMapping);
+      } else {
+        // Prepare reloaded values for results of operations that cannot be
+        // safely parallelized and which are used after the region `sr`
+        for (auto res : op.getResults()) {
+          for (auto &use : res.getUses()) {
+            Operation *user = use.getOwner();
+            while (user->getParentOp() != wsOp)
+              user = user->getParentOp();
+            if (!user->isBeforeInBlock(&*sr.end)) {
+              // We need to reload
+              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            }
+          }
+        }
+      }
+    }
+    singleBuilder.create<omp::TerminatorOp>(loc);
+  };
+
+  Block *wsBlock = &wsOp.getRegion().front();
+  assert(wsBlock->getTerminator()->getNumOperands() == 0);
+  Operation *terminator = wsBlock->getTerminator();
+
+  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+
+  auto it = wsBlock->begin();
+  auto getSingleRegion = [&]() {
+    if (&*it == terminator)
+      return false;
+    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+      regions.push_back(pop);
+      it++;
+      return true;
+    }
+    SingleRegion sr;
+    sr.begin = it;
+    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+      it++;
+    sr.end = it;
+    assert(sr.begin != sr.end);
+    regions.push_back(sr);
+    return true;
+  };
+  while (getSingleRegion())
+    ;
+
+  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
+    bool isLast = i + 1 == regions.size();
+    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
+      omp::SingleOperands singleOperands;
+      if (isLast)
+        singleOperands.nowait = rootBuilder.getUnitAttr();
+      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      OpBuilder singleBuilder(singleOp);
+      singleBuilder.createBlock(&singleOp.getRegion());
+      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
+    } else {
+      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
+      if (!isLast)
+        rootBuilder.create<omp::BarrierOp>(loc);
+    }
+  }
+
+  if (!wsOp.getNowait())
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  wsOp->erase();
+
+  return;
+}
+
+class LowerWorksharePass
+    : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
+public:
+  void runOnOperation() override {
+    SmallPtrSet<Operation *, 8> parents;
+    getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
+      Operation *isolatedParent =
+          wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+      parents.insert(isolatedParent);
+
+      lowerWorkshare(wsOp);
+    });
+
+    // Do folding
+    for (Operation *isolatedParent : parents) {
+      RewritePatternSet patterns(&getContext());
+      GreedyRewriteConfig config;
+      // Prevent the pattern driver from merging blocks.
+      config.enableRegionSimplification =
+          mlir::GreedySimplifyRegionLevel::Disabled;
+      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
+                                              std::move(patterns), config))) {
+        emitError(isolatedParent->getLoc(), "error in lower workshare\n");
+        signalPassFailure();
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
new file mode 100644
index 00000000000000..a8d36443f08bda
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -0,0 +1,81 @@
+// RUN: fir-opt --lower-workshare %s | FileCheck %s
+
+module {
+// CHECK-LABEL:   func.func @simple(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           omp.parallel {
+// CHECK:             omp.single {
+// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
+// CHECK:               llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
+// CHECK:               %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
+// CHECK:             %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
+// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                 omp.yield
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.barrier
+// CHECK:             omp.single nowait {
+// CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.barrier
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+  func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
+    omp.parallel {
+      omp.workshare {
+        %c42 = arith.constant 42 : index
+        %c1_i32 = arith.constant 1 : i32
+        %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+        %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+        %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+        %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+        %true = arith.constant true
+        %c1 = arith.constant 1 : index
+        omp.wsloop {
+          omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+            %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+            %8 = fir.load %7 : !fir.ref<i32>
+            %9 = arith.subi %8, %c1_i32 : i32
+            %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+            hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+            omp.yield
+          }
+          omp.terminator
+        }
+        %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+        fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    return
+  }
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare5.mlir b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
new file mode 100644
index 00000000000000..177f8aa8f86c7c
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
@@ -0,0 +1,42 @@
+// XFAIL: *
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// TODO we can lower these but we have no guarantee that the parent of
+// omp.workshare supports multi-block regions, thus we fail for now.
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+    ^bb1:
+      %c1 = arith.constant 1 : i32
+      cf.br ^bb3(%c1: i32)
+    ^bb3(%arg1: i32):
+      "test.test2"(%arg1) : (i32) -> ()
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+    ^bb1:
+      %c1 = arith.constant 1 : i32
+      cf.br ^bb3(%c1: i32)
+    ^bb2:
+      "test.test2"(%r) : (i32) -> ()
+      omp.terminator
+    ^bb3(%arg1: i32):
+      %r = "test.test2"(%arg1) : (i32) -> i32
+      cf.br ^bb2
+    }
+    omp.terminator
+  }
+  return
+}

>From cbb14172cb5b72719081cd623f0608a2b6186950 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:41:09 +0900
Subject: [PATCH 08/23] Change to workshare loop wrapper op

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 24 ++++++++++++-------
 .../Transforms/OpenMP/lower-workshare.mlir    |  5 ++--
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 40975552d1fe33..cb342b60de4e8d 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/iterator_range.h"
 
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
 #include <variant>
 
 namespace flangomp {
@@ -73,7 +74,7 @@ static bool isSafeToParallelize(Operation *op) {
 ///
 /// omp.workshare {
 ///   %a = fir.allocmem
-///   omp.wsloop {}
+///   omp.workshare_loop_wrapper {}
 ///   fir.call Assign %b %a
 ///   fir.freemem %a
 /// }
@@ -85,7 +86,7 @@ static bool isSafeToParallelize(Operation *op) {
 ///   fir.store %a %tmp
 /// }
 /// %a_reloaded = fir.load %tmp
-/// omp.wsloop {}
+/// omp.workshare_loop_wrapper {}
 /// omp.single {
 ///   fir.call Assign %b %a_reloaded
 ///   fir.freemem %a_reloaded
@@ -180,20 +181,20 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
   assert(wsBlock->getTerminator()->getNumOperands() == 0);
   Operation *terminator = wsBlock->getTerminator();
 
-  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+  SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
 
   auto it = wsBlock->begin();
   auto getSingleRegion = [&]() {
     if (&*it == terminator)
       return false;
-    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+    if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
       regions.push_back(pop);
       it++;
       return true;
     }
     SingleRegion sr;
     sr.begin = it;
-    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+    while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
       it++;
     sr.end = it;
     assert(sr.begin != sr.end);
@@ -214,9 +215,16 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
       singleBuilder.createBlock(&singleOp.getRegion());
       moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
     } else {
-      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
-      if (!isLast)
-        rootBuilder.create<omp::BarrierOp>(loc);
+      omp::WsloopOperands wsloopOperands;
+      if (isLast)
+        wsloopOperands.nowait = rootBuilder.getUnitAttr();
+      auto wsloop =
+          rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+      auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
+      auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+          rootBuilder.clone(*wslw, rootMapping));
+      wsloop.getRegion().takeBody(clonedWslw.getRegion());
+      clonedWslw->erase();
     }
   }
 
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index a8d36443f08bda..cb5791d35916a9 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -34,7 +34,6 @@ module {
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             omp.barrier
 // CHECK:             omp.single nowait {
 // CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
 // CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
@@ -56,7 +55,7 @@ module {
         %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
         %true = arith.constant true
         %c1 = arith.constant 1 : index
-        omp.wsloop {
+        "omp.workshare_loop_wrapper"() ({
           omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
             %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
             %8 = fir.load %7 : !fir.ref<i32>
@@ -66,7 +65,7 @@ module {
             omp.yield
           }
           omp.terminator
-        }
+        }) : () -> ()
         %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
         %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
         %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>

>From 5f4f9bf1267cff9f9f413bd3b667661aab75b9b3 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:47:27 +0900
Subject: [PATCH 09/23] Move single op declaration

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index cb342b60de4e8d..2322d2acbc0138 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -120,8 +120,6 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
   OpBuilder rootBuilder(wsOp);
   IRMapping rootMapping;
 
-  omp::SingleOp singleOp = nullptr;
-
   auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
                               IRMapping singleMapping) {
     if (auto reloaded = rootMapping.lookupOrNull(v))
@@ -210,7 +208,8 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
       omp::SingleOperands singleOperands;
       if (isLast)
         singleOperands.nowait = rootBuilder.getUnitAttr();
-      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      omp::SingleOp singleOp =
+          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
       OpBuilder singleBuilder(singleOp);
       singleBuilder.createBlock(&singleOp.getRegion());
       moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);

>From 79ae18461e1909637f206d10ac286d8dcb5ef18c Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 17:13:58 +0900
Subject: [PATCH 10/23] Schedule pass properly

---
 flang/include/flang/Tools/CLOptions.inc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index a565effebfa92e..b256dba3c92949 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  pm.addPass(flangomp::createLowerWorkshare());
+  addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed

>From 5911275afa122ebc77ef5f653799cb19ee5af187 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 00:30:40 +0900
Subject: [PATCH 11/23] Correctly handle nested loop nests to be
 parallelized by workshare

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 256 ++++++++++--------
 1 file changed, 138 insertions(+), 118 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 2322d2acbc0138..8e79d1401c01c6 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -19,9 +19,14 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/iterator_range.h"
 
+#include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <variant>
 
 namespace flangomp {
@@ -52,90 +57,40 @@ static bool isSupportedByFirAlloca(Type ty) {
   return !isa<fir::ReferenceType>(ty);
 }
 
-static bool isSafeToParallelize(Operation *op) {
-  if (isa<fir::DeclareOp>(op))
-    return true;
-
-  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
-  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
-  if (!interface) {
-    return false;
-  }
-  interface.getEffects(effects);
-  if (effects.empty())
-    return true;
-
-  return false;
+static bool mustParallelizeOp(Operation *op) {
+  return op
+      ->walk(
+          [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
+      .wasInterrupted();
 }
 
-/// Lowers workshare to a sequence of single-thread regions and parallel loops
-///
-/// For example:
-///
-/// omp.workshare {
-///   %a = fir.allocmem
-///   omp.workshare_loop_wrapper {}
-///   fir.call Assign %b %a
-///   fir.freemem %a
-/// }
-///
-/// becomes
-///
-/// omp.single {
-///   %a = fir.allocmem
-///   fir.store %a %tmp
-/// }
-/// %a_reloaded = fir.load %tmp
-/// omp.workshare_loop_wrapper {}
-/// omp.single {
-///   fir.call Assign %b %a_reloaded
-///   fir.freemem %a_reloaded
-/// }
-///
-/// Note that we allocate temporary memory for values in omp.single's which need
-/// to be accessed in all threads in the closest omp.parallel
-///
-/// TODO currently we need to be able to access the encompassing omp.parallel so
-/// that we can allocate temporaries accessible by all threads outside of it.
-/// In case we do not find it, we fall back to converting the omp.workshare to
-/// omp.single.
-/// To better handle this we should probably enable yielding values out of an
-/// omp.single which will be supported by the omp runtime.
-void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
-  assert(wsOp.getRegion().getBlocks().size() == 1);
-
-  Location loc = wsOp->getLoc();
+static bool isSafeToParallelize(Operation *op) {
+  return isa<fir::DeclareOp>(op) || isPure(op);
+}
 
-  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
-  if (!parallelOp) {
-    wsOp.emitWarning("cannot handle workshare, converting to single");
-    Operation *terminator = wsOp.getRegion().front().getTerminator();
-    wsOp->getBlock()->getOperations().splice(
-        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
-    terminator->erase();
-    return;
-  }
-
-  OpBuilder allocBuilder(parallelOp);
-  OpBuilder rootBuilder(wsOp);
-  IRMapping rootMapping;
+static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
+                              IRMapping &rootMapping, Location loc) {
+  Operation *parentOp = sourceRegion.getParentOp();
+  OpBuilder rootBuilder(sourceRegion.getContext());
 
+  // TODO need to copyprivate the alloca's
   auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
                               IRMapping singleMapping) {
+    OpBuilder allocaBuilder(&targetRegion.front().front());
     if (auto reloaded = rootMapping.lookupOrNull(v))
       return;
-    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
     Value alloc, reloaded;
     if (isSupportedByFirAlloca(ty)) {
-      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
       singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
       reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
     } else {
-      auto one = allocBuilder.create<LLVM::ConstantOp>(
-          loc, allocBuilder.getI32Type(), 1);
+      auto one = allocaBuilder.create<LLVM::ConstantOp>(
+          loc, allocaBuilder.getI32Type(), 1);
       alloc =
-          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+          allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
       Value toStore = singleBuilder
                           .create<UnrealizedConversionCastOp>(
                               loc, llvmPtrTy, singleMapping.lookup(v))
@@ -162,9 +117,10 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
         for (auto res : op.getResults()) {
           for (auto &use : res.getUses()) {
             Operation *user = use.getOwner();
-            while (user->getParentOp() != wsOp)
+            while (user->getParentOp() != parentOp)
               user = user->getParentOp();
-            if (!user->isBeforeInBlock(&*sr.end)) {
+            if (!(user->isBeforeInBlock(&*sr.end) &&
+                  sr.begin->isBeforeInBlock(user))) {
               // We need to reload
               mapReloadedValue(use.get(), singleBuilder, singleMapping);
             }
@@ -175,61 +131,125 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
     singleBuilder.create<omp::TerminatorOp>(loc);
   };
 
-  Block *wsBlock = &wsOp.getRegion().front();
-  assert(wsBlock->getTerminator()->getNumOperands() == 0);
-  Operation *terminator = wsBlock->getTerminator();
+  // TODO Need to handle these (clone them) in dominator tree order
+  for (Block &block : sourceRegion) {
+    rootBuilder.createBlock(
+        &targetRegion, {}, block.getArgumentTypes(),
+        llvm::map_to_vector(block.getArguments(),
+                            [](BlockArgument arg) { return arg.getLoc(); }));
+    Operation *terminator = block.getTerminator();
 
-  SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
+    SmallVector<std::variant<SingleRegion, Operation *>> regions;
 
-  auto it = wsBlock->begin();
-  auto getSingleRegion = [&]() {
-    if (&*it == terminator)
-      return false;
-    if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
-      regions.push_back(pop);
-      it++;
+    auto it = block.begin();
+    auto getOneRegion = [&]() {
+      if (&*it == terminator)
+        return false;
+      if (mustParallelizeOp(&*it)) {
+        regions.push_back(&*it);
+        it++;
+        return true;
+      }
+      SingleRegion sr;
+      sr.begin = it;
+      while (&*it != terminator && !mustParallelizeOp(&*it))
+        it++;
+      sr.end = it;
+      assert(sr.begin != sr.end);
+      regions.push_back(sr);
       return true;
+    };
+    while (getOneRegion())
+      ;
+
+    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
+      bool isLast = i + 1 == regions.size();
+      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        omp::SingleOperands singleOperands;
+        if (isLast)
+          singleOperands.nowait = rootBuilder.getUnitAttr();
+        omp::SingleOp singleOp =
+            rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+        OpBuilder singleBuilder(singleOp);
+        singleBuilder.createBlock(&singleOp.getRegion());
+        moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
+      } else {
+        auto op = std::get<Operation *>(opOrSingle);
+        if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
+          omp::WsloopOperands wsloopOperands;
+          if (isLast)
+            wsloopOperands.nowait = rootBuilder.getUnitAttr();
+          auto wsloop =
+              rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+          auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+              rootBuilder.clone(*wslw, rootMapping));
+          wsloop.getRegion().takeBody(clonedWslw.getRegion());
+          clonedWslw->erase();
+        } else {
+          assert(mustParallelizeOp(op));
+          Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
+          for (auto [region, clonedRegion] :
+               llvm::zip(op->getRegions(), cloned->getRegions()))
+            parallelizeRegion(region, clonedRegion, rootMapping, loc);
+        }
+      }
     }
-    SingleRegion sr;
-    sr.begin = it;
-    while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
-      it++;
-    sr.end = it;
-    assert(sr.begin != sr.end);
-    regions.push_back(sr);
-    return true;
-  };
-  while (getSingleRegion())
-    ;
-
-  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
-    bool isLast = i + 1 == regions.size();
-    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
-      omp::SingleOperands singleOperands;
-      if (isLast)
-        singleOperands.nowait = rootBuilder.getUnitAttr();
-      omp::SingleOp singleOp =
-          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
-      OpBuilder singleBuilder(singleOp);
-      singleBuilder.createBlock(&singleOp.getRegion());
-      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
-    } else {
-      omp::WsloopOperands wsloopOperands;
-      if (isLast)
-        wsloopOperands.nowait = rootBuilder.getUnitAttr();
-      auto wsloop =
-          rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
-      auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
-      auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
-          rootBuilder.clone(*wslw, rootMapping));
-      wsloop.getRegion().takeBody(clonedWslw.getRegion());
-      clonedWslw->erase();
-    }
+
+    rootBuilder.clone(*block.getTerminator(), rootMapping);
   }
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.workshare_loop_wrapper {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.workshare_loop_wrapper {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which need
+/// to be accessed in all threads in the closest omp.parallel
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  Location loc = wsOp->getLoc();
+  IRMapping rootMapping;
+
+  OpBuilder rootBuilder(wsOp);
+
+  // TODO We need something like an scf;execute here, but that is not registered
+  // so using fir.if for now but it looks like it does not support multiple
+  // blocks so it doesnt work for multi block case...
+  auto ifOp = rootBuilder.create<fir::IfOp>(
+      loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
+  ifOp.getThenRegion().front().erase();
+
+  parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
+
+  Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
+  assert(isa<omp::TerminatorOp>(terminatorOp));
+  OpBuilder termBuilder(terminatorOp);
 
   if (!wsOp.getNowait())
-    rootBuilder.create<omp::BarrierOp>(loc);
+    termBuilder.create<omp::BarrierOp>(loc);
+
+  termBuilder.create<fir::ResultOp>(loc, ValueRange());
 
+  terminatorOp->erase();
   wsOp->erase();
 
   return;

>From 3604b3dd80e1e1b8347c66cc1d24e2ed72c998d9 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 00:33:57 +0900
Subject: [PATCH 12/23] Leave comments for shouldUseWorkshareLowering

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 8e79d1401c01c6..40dae0fd848ef8 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -40,10 +40,23 @@ using namespace mlir;
 
 namespace flangomp {
 bool shouldUseWorkshareLowering(Operation *op) {
-  auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
-  if (!workshare)
-    return false;
-  return workshare->getParentOfType<omp::ParallelOp>();
+  // TODO this is insufficient, as we could have
+  // omp.parallel {
+  //   omp.workshare {
+  //     omp.parallel {
+  //       hlfir.elemental {}
+  //
+  // Then this hlfir.elemental shall _not_ use the lowering for workshare
+  //
+  // Standard says:
+  //   For a parallel construct, the construct is a unit of work with respect to
+  //   the workshare construct. The statements contained in the parallel
+  //   construct are executed by a new thread team.
+  //
+  // TODO similarly for single, critical, etc. Need to think through the
+  // patterns and implement this function.
+  //
+  return op->getParentOfType<omp::WorkshareOp>();
 }
 } // namespace flangomp
 

>From ea26f32c964cf59ced7d735785c40f4050269e1f Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 13:14:38 +0900
Subject: [PATCH 13/23] Use copyprivate to scatter val from omp.single

TODO still need to implement copy function
TODO transitive check for usage outside of omp.single not implemented yet
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |   3 +-
 flang/include/flang/Tools/CLOptions.inc       |   2 +-
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 138 ++++++++++++++----
 3 files changed, 109 insertions(+), 34 deletions(-)

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 1c9d75d8cfaa18..041240cad12eb3 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -37,7 +37,8 @@ def FunctionFiltering : Pass<"omp-function-filtering"> {
   ];
 }
 
-def LowerWorkshare : Pass<"lower-workshare"> {
+// Needs to be scheduled on Module as we create functions in it
+def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index b256dba3c92949..a565effebfa92e 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
+  pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 40dae0fd848ef8..950737fccada79 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -8,25 +8,27 @@
 // Lower omp workshare construct.
 //===----------------------------------------------------------------------===//
 
-#include "flang/Optimizer/Dialect/FIROps.h"
-#include "flang/Optimizer/Dialect/FIRType.h"
-#include "flang/Optimizer/OpenMP/Passes.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVectorExtras.h"
-#include "llvm/ADT/iterator_range.h"
-
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
-#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
 #include <mlir/IR/Visitors.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
+
 #include <variant>
 
 namespace flangomp {
@@ -71,6 +73,8 @@ static bool isSupportedByFirAlloca(Type ty) {
 }
 
 static bool mustParallelizeOp(Operation *op) {
+  // TODO as in shouldUseWorkshareLowering we be careful not to pick up
+  // workshare_loop_wrapper in nested omp.parallel ops
   return op
       ->walk(
           [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
@@ -78,7 +82,33 @@ static bool mustParallelizeOp(Operation *op) {
 }
 
 static bool isSafeToParallelize(Operation *op) {
-  return isa<fir::DeclareOp>(op) || isPure(op);
+  return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
+         isMemoryEffectFree(op);
+}
+
+static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
+                                         fir::FirOpBuilder builder) {
+  mlir::ModuleOp module = builder.getModule();
+  mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
+
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+
+  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
+    return decl;
+  // create function
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
+  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
+  mlir::func::FuncOp funcOp =
+      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+                      {loc, loc});
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+  builder.create<mlir::func::ReturnOp>(loc);
+  return funcOp;
 }
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
@@ -86,19 +116,23 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
   Operation *parentOp = sourceRegion.getParentOp();
   OpBuilder rootBuilder(sourceRegion.getContext());
 
+  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
+  OpBuilder copyFuncBuilder(m.getBodyRegion());
+  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
+
   // TODO need to copyprivate the alloca's
-  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
-                              IRMapping singleMapping) {
-    OpBuilder allocaBuilder(&targetRegion.front().front());
+  auto mapReloadedValue =
+      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
+          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
     if (auto reloaded = rootMapping.lookupOrNull(v))
-      return;
+      return nullptr;
     Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
     Value alloc, reloaded;
     if (isSupportedByFirAlloca(ty)) {
       alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
       singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
-      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+      reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
     } else {
       auto one = allocaBuilder.create<LLVM::ConstantOp>(
           loc, allocaBuilder.getI32Type(), 1);
@@ -109,21 +143,25 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               loc, llvmPtrTy, singleMapping.lookup(v))
                           .getResult(0);
       singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
-      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
       reloaded =
-          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+          parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
               .getResult(0);
     }
     rootMapping.map(v, reloaded);
+    return alloc;
   };
 
-  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
+                          OpBuilder singleBuilder,
+                          OpBuilder parallelBuilder) -> SmallVector<Value> {
     IRMapping singleMapping = rootMapping;
+    SmallVector<Value> copyPrivate;
 
     for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
       singleBuilder.clone(op, singleMapping);
       if (isSafeToParallelize(&op)) {
-        rootBuilder.clone(op, rootMapping);
+        parallelBuilder.clone(op, rootMapping);
       } else {
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
@@ -132,16 +170,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
             Operation *user = use.getOwner();
             while (user->getParentOp() != parentOp)
               user = user->getParentOp();
-            if (!(user->isBeforeInBlock(&*sr.end) &&
-                  sr.begin->isBeforeInBlock(user))) {
-              // We need to reload
-              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            // TODO we need to look at transitively used vals
+            if (true || !(user->isBeforeInBlock(&*sr.end) &&
+                          sr.begin->isBeforeInBlock(user))) {
+              auto alloc =
+                  mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
+                                   parallelBuilder, singleMapping);
+              if (alloc)
+                copyPrivate.push_back(alloc);
             }
           }
         }
       }
     }
     singleBuilder.create<omp::TerminatorOp>(loc);
+    return copyPrivate;
   };
 
   // TODO Need to handle these (clone them) in dominator tree order
@@ -178,14 +221,45 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
       bool isLast = i + 1 == regions.size();
       if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        OpBuilder singleBuilder(sourceRegion.getContext());
+        Block *singleBlock = new Block();
+        singleBuilder.setInsertionPointToStart(singleBlock);
+
+        OpBuilder allocaBuilder(sourceRegion.getContext());
+        Block *allocaBlock = new Block();
+        allocaBuilder.setInsertionPointToStart(allocaBlock);
+
+        OpBuilder parallelBuilder(sourceRegion.getContext());
+        Block *parallelBlock = new Block();
+        parallelBuilder.setInsertionPointToStart(parallelBlock);
+
         omp::SingleOperands singleOperands;
         if (isLast)
           singleOperands.nowait = rootBuilder.getUnitAttr();
+        auto insPtAtSingle = rootBuilder.saveInsertionPoint();
+        singleOperands.copyprivateVars =
+            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
+                         singleBuilder, parallelBuilder);
+        for (auto var : singleOperands.copyprivateVars) {
+          Type ty;
+          if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
+            ty = firAlloca.getAllocatedType();
+          } else {
+            llvm_unreachable("unexpected");
+          }
+          mlir::func::FuncOp funcOp =
+              createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
+          singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
+        }
         omp::SingleOp singleOp =
             rootBuilder.create<omp::SingleOp>(loc, singleOperands);
-        OpBuilder singleBuilder(singleOp);
-        singleBuilder.createBlock(&singleOp.getRegion());
-        moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
+        singleOp.getRegion().push_back(singleBlock);
+        rootBuilder.getInsertionBlock()->getOperations().splice(
+            rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
+        targetRegion.front().getOperations().splice(
+            singleOp->getIterator(), allocaBlock->getOperations());
+        delete allocaBlock;
+        delete parallelBlock;
       } else {
         auto op = std::get<Operation *>(opOrSingle);
         if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {

>From add3b79ee577266b38b4b01a9d3dfe7d00ee6fa6 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 13:39:49 +0900
Subject: [PATCH 14/23] Transitively check for users outside of single op

TODO need to implement copy func
TODO need to hoist allocas outside of single regions
---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 51 ++++++++++++++-----
 1 file changed, 37 insertions(+), 14 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 950737fccada79..2e88d852ff2cba 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -111,6 +111,38 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
   return funcOp;
 }
 
+static bool isUserOutsideSR(Operation *user, Operation *parentOp,
+                            SingleRegion sr) {
+  while (user->getParentOp() != parentOp)
+    user = user->getParentOp();
+  return sr.begin->getBlock() != user->getBlock() ||
+         !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
+}
+
+static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
+  Block *srBlock = sr.begin->getBlock();
+  Operation *parentOp = srBlock->getParentOp();
+
+  for (auto &use : v.getUses()) {
+    Operation *user = use.getOwner();
+    if (isUserOutsideSR(user, parentOp, sr))
+      return true;
+
+    // Results of nested users cannot be used outside of the SR
+    if (user->getBlock() != srBlock)
+      continue;
+
+    // A non-safe to parallelize operation will be handled separately
+    if (!isSafeToParallelize(user))
+      continue;
+
+    for (auto res : user->getResults())
+      if (isTransitivelyUsedOutside(res, sr))
+        return true;
+  }
+  return false;
+}
+
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
   Operation *parentOp = sourceRegion.getParentOp();
@@ -166,19 +198,11 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
         for (auto res : op.getResults()) {
-          for (auto &use : res.getUses()) {
-            Operation *user = use.getOwner();
-            while (user->getParentOp() != parentOp)
-              user = user->getParentOp();
-            // TODO we need to look at transitively used vals
-            if (true || !(user->isBeforeInBlock(&*sr.end) &&
-                          sr.begin->isBeforeInBlock(user))) {
-              auto alloc =
-                  mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
-                                   parallelBuilder, singleMapping);
-              if (alloc)
-                copyPrivate.push_back(alloc);
-            }
+          if (isTransitivelyUsedOutside(res, sr)) {
+            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
+                                          parallelBuilder, singleMapping);
+            if (alloc)
+              copyPrivate.push_back(alloc);
           }
         }
       }
@@ -236,7 +260,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         omp::SingleOperands singleOperands;
         if (isLast)
           singleOperands.nowait = rootBuilder.getUnitAttr();
-        auto insPtAtSingle = rootBuilder.saveInsertionPoint();
         singleOperands.copyprivateVars =
             moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                          singleBuilder, parallelBuilder);

>From 92c1e7787f772333fbca37f1b2c83b3f1fb6b257 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 14:30:43 +0900
Subject: [PATCH 15/23] Add tests

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  25 +-
 .../Transforms/OpenMP/lower-workshare.mlir    | 230 +++++++++++++-----
 2 files changed, 188 insertions(+), 67 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 2e88d852ff2cba..30af2556cf4cae 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -18,6 +18,7 @@
 #include <llvm/ADT/iterator_range.h>
 #include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
 #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
 #include <mlir/IR/BuiltinOps.h>
@@ -75,6 +76,14 @@ static bool isSupportedByFirAlloca(Type ty) {
 static bool mustParallelizeOp(Operation *op) {
   // TODO as in shouldUseWorkshareLowering we be careful not to pick up
   // workshare_loop_wrapper in nested omp.parallel ops
+  //
+  // e.g.
+  //
+  // omp.parallel {
+  //   omp.workshare {
+  //     omp.parallel {
+  //       omp.workshare {
+  //         omp.workshare_loop_wrapper {}
   return op
       ->walk(
           [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
@@ -89,10 +98,14 @@ static bool isSafeToParallelize(Operation *op) {
 static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                          fir::FirOpBuilder builder) {
   mlir::ModuleOp module = builder.getModule();
-  mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
-
-  std::string copyFuncName =
-      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  std::string copyFuncName;
+  if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
+    mlir::Type eleTy = rt.getEleTy();
+    copyFuncName =
+        fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  } else {
+    copyFuncName = "_workshare_copy_llvm_ptr";
+  }
 
   if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
     return decl;
@@ -145,9 +158,7 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
-  Operation *parentOp = sourceRegion.getParentOp();
   OpBuilder rootBuilder(sourceRegion.getContext());
-
   ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
   OpBuilder copyFuncBuilder(m.getBodyRegion());
   fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
@@ -268,7 +279,7 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
             ty = firAlloca.getAllocatedType();
           } else {
-            llvm_unreachable("unexpected");
+            ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
           }
           mlir::func::FuncOp funcOp =
               createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index cb5791d35916a9..19123e71cacf60 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -1,80 +1,190 @@
-// RUN: fir-opt --lower-workshare %s | FileCheck %s
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
-module {
-// CHECK-LABEL:   func.func @simple(
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+    omp.workshare {
+      %c42 = arith.constant 42 : index
+      %c1_i32 = arith.constant 1 : i32
+      %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+      %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+      %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+      %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+      %true = arith.constant true
+      %c1 = arith.constant 1 : index
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+          %8 = fir.load %7 : !fir.ref<i32>
+          %9 = arith.subi %8, %c1_i32 : i32
+          %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+          hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+      %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+      fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+
+// -----
+
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.workshare {
+    %c1_i32 = arith.constant 1 : i32
+    %alloc = fir.alloca i32
+    fir.store %c1_i32 to %alloc : !fir.ref<i32>
+    %c42 = arith.constant 42 : index
+    %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+    %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+    %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+    %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+    %true = arith.constant true
+    %c1 = arith.constant 1 : index
+    "omp.workshare_loop_wrapper"() ({
+      omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+        %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        %8 = fir.load %7 : !fir.ref<i32>
+        %ld = fir.load %alloc : !fir.ref<i32>
+        %n8 = arith.subi %8, %ld : i32
+        %9 = arith.subi %n8, %c1_i32 : i32
+        %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
+    }) : () -> ()
+    %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    "test.test1"(%alloc) : (!fir.ref<i32>) -> ()
+    hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+    fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+    omp.terminator
+  }
+  return
+}
+
+
+// CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK:           %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
-// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           %[[VAL_4:.*]] = arith.constant true
 // CHECK:           omp.parallel {
-// CHECK:             omp.single {
-// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
-// CHECK:               llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
-// CHECK:               %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             fir.if %[[VAL_4]] {
+// CHECK:               %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:                 %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:                 %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:                 %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:                 fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:                 %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               omp.wsloop {
+// CHECK:                 omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                   %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK:                   %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
+// CHECK:                   %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.single nowait {
+// CHECK:                 hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:                 fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.barrier
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_workshare_copy_llvm_ptr(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME:                                                %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @wsfunc(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_5:.*]] = arith.constant true
+// CHECK:           fir.if %[[VAL_5]] {
+// CHECK:             %[[VAL_6:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK:             %[[VAL_7:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_6]] -> @_workshare_copy_llvm_ptr : !llvm.ptr, %[[VAL_7]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               %[[VAL_8:.*]] = fir.alloca i32
+// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !fir.ref<i32> to !llvm.ptr
+// CHECK:               llvm.store %[[VAL_9]], %[[VAL_6]] : !llvm.ptr, !llvm.ptr
+// CHECK:               fir.store %[[VAL_3]] to %[[VAL_8]] : !fir.ref<i32>
+// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_12:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_12]] to %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
-// CHECK:             %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
-// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_14:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> !llvm.ptr
+// CHECK:             %[[VAL_15:.*]] = builtin.unrealized_conversion_cast %[[VAL_14]] : !llvm.ptr to !fir.ref<i32>
+// CHECK:             %[[VAL_16:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_16]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_18]](%[[VAL_16]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
-// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.loop_nest (%[[VAL_20:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_17]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_24:.*]] = arith.subi %[[VAL_22]], %[[VAL_23]] : i32
+// CHECK:                 %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_3]] : i32
+// CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_19]]#0 (%[[VAL_20]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_25]] to %[[VAL_26]] temporary_lhs : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.single nowait {
-// CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK:               "test.test1"(%[[VAL_15]]) : (!fir.ref<i32>) -> ()
+// CHECK:               hlfir.assign %[[VAL_19]]#0 to %[[VAL_17]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_19]]#0 : !fir.heap<!fir.array<42xi32>>
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.barrier
-// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
-  func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
-    omp.parallel {
-      omp.workshare {
-        %c42 = arith.constant 42 : index
-        %c1_i32 = arith.constant 1 : i32
-        %0 = fir.shape %c42 : (index) -> !fir.shape<1>
-        %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-        %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-        %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-        %true = arith.constant true
-        %c1 = arith.constant 1 : index
-        "omp.workshare_loop_wrapper"() ({
-          omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
-            %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-            %8 = fir.load %7 : !fir.ref<i32>
-            %9 = arith.subi %8, %c1_i32 : i32
-            %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-            hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
-            omp.yield
-          }
-          omp.terminator
-        }) : () -> ()
-        %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-        fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
-        omp.terminator
-      }
-      omp.terminator
-    }
-    return
-  }
-}
+

>From 6fe32271a39c9381c7de68704e9d939388c5dfff Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 15:50:58 +0900
Subject: [PATCH 16/23] Hoist allocas

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 10 ++-
 .../Transforms/OpenMP/lower-workshare.mlir    | 69 +++++++++----------
 2 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 30af2556cf4cae..d0cd235d3eb079 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -163,7 +163,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
   OpBuilder copyFuncBuilder(m.getBodyRegion());
   fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
 
-  // TODO need to copyprivate the alloca's
   auto mapReloadedValue =
       [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
           OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
@@ -202,10 +201,17 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     SmallVector<Value> copyPrivate;
 
     for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
-      singleBuilder.clone(op, singleMapping);
       if (isSafeToParallelize(&op)) {
+        singleBuilder.clone(op, singleMapping);
         parallelBuilder.clone(op, rootMapping);
+      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
+        auto hoisted =
+            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
+        rootMapping.map(&*alloca, &*hoisted);
+        rootMapping.map(alloca.getResult(), hoisted.getResult());
+        copyPrivate.push_back(hoisted);
       } else {
+        singleBuilder.clone(op, singleMapping);
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
         for (auto res : op.getResults()) {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 19123e71cacf60..b78cfd80e17acb 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -1,5 +1,7 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
+// checks:
+// nowait on final omp.single
 func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   omp.parallel {
     omp.workshare {
@@ -37,6 +39,8 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // -----
 
+// checks:
+// fir.alloca hoisted out and copyprivate'd
 func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   omp.workshare {
     %c1_i32 = arith.constant 1 : i32
@@ -73,7 +77,6 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   return
 }
 
-
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
@@ -130,9 +133,9 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           return
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_workshare_copy_llvm_ptr(
-// CHECK-SAME:                                                %[[VAL_0:.*]]: !llvm.ptr,
-// CHECK-SAME:                                                %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK-LABEL:   func.func private @_workshare_copy_i32(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: !fir.ref<i32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: !fir.ref<i32>) {
 // CHECK:           return
 // CHECK:         }
 
@@ -141,46 +144,40 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK:           %[[VAL_5:.*]] = arith.constant true
-// CHECK:           fir.if %[[VAL_5]] {
-// CHECK:             %[[VAL_6:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
-// CHECK:             %[[VAL_7:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:             omp.single copyprivate(%[[VAL_6]] -> @_workshare_copy_llvm_ptr : !llvm.ptr, %[[VAL_7]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:               %[[VAL_8:.*]] = fir.alloca i32
-// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !fir.ref<i32> to !llvm.ptr
-// CHECK:               llvm.store %[[VAL_9]], %[[VAL_6]] : !llvm.ptr, !llvm.ptr
-// CHECK:               fir.store %[[VAL_3]] to %[[VAL_8]] : !fir.ref<i32>
-// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_12:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               fir.store %[[VAL_12]] to %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_4:.*]] = arith.constant true
+// CHECK:           fir.if %[[VAL_4]] {
+// CHECK:             %[[VAL_5:.*]] = fir.alloca i32
+// CHECK:             %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
+// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_14:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> !llvm.ptr
-// CHECK:             %[[VAL_15:.*]] = builtin.unrealized_conversion_cast %[[VAL_14]] : !llvm.ptr to !fir.ref<i32>
-// CHECK:             %[[VAL_16:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:             %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_16]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:             %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_18]](%[[VAL_16]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_20:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_17]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_24:.*]] = arith.subi %[[VAL_22]], %[[VAL_23]] : i32
-// CHECK:                 %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_3]] : i32
-// CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_19]]#0 (%[[VAL_20]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_25]] to %[[VAL_26]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
+// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
+// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.single nowait {
-// CHECK:               "test.test1"(%[[VAL_15]]) : (!fir.ref<i32>) -> ()
-// CHECK:               hlfir.assign %[[VAL_19]]#0 to %[[VAL_17]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_19]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
+// CHECK:               hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.barrier

>From a9d8578f5bc00c199338666866126dbfe86d3089 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 15:52:23 +0900
Subject: [PATCH 17/23] More tests

---
 .../Transforms/OpenMP/lower-workshare2.mlir   | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare2.mlir

diff --git a/flang/test/Transforms/OpenMP/lower-workshare2.mlir b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
new file mode 100644
index 00000000000000..325a40d4184453
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -0,0 +1,21 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @nonowait
+func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  // CHECK: omp.barrier
+  omp.workshare {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @nowait
+func.func @nowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  // CHECK-NOT: omp.barrier
+  omp.workshare nowait {
+    omp.terminator
+  }
+  return
+}

>From 8395991ea4e1e034c9d107641c6b4119008aa403 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:13:39 +0900
Subject: [PATCH 18/23] Emit body for copy func

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 44 +++++--------------
 .../Transforms/OpenMP/lower-workshare.mlir    |  6 +++
 2 files changed, 17 insertions(+), 33 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index d0cd235d3eb079..20f45296a8159a 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -69,10 +69,6 @@ struct SingleRegion {
   Block::iterator begin, end;
 };
 
-static bool isSupportedByFirAlloca(Type ty) {
-  return !isa<fir::ReferenceType>(ty);
-}
-
 static bool mustParallelizeOp(Operation *op) {
   // TODO as in shouldUseWorkshareLowering we be careful not to pick up
   // workshare_loop_wrapper in nested omp.parallel ops
@@ -98,14 +94,10 @@ static bool isSafeToParallelize(Operation *op) {
 static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                          fir::FirOpBuilder builder) {
   mlir::ModuleOp module = builder.getModule();
-  std::string copyFuncName;
-  if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
-    mlir::Type eleTy = rt.getEleTy();
-    copyFuncName =
-        fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
-  } else {
-    copyFuncName = "_workshare_copy_llvm_ptr";
-  }
+  auto rt = cast<fir::ReferenceType>(varType);
+  mlir::Type eleTy = rt.getEleTy();
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
 
   if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
     return decl;
@@ -120,6 +112,10 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
   builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
                       {loc, loc});
   builder.setInsertionPointToStart(&funcOp.getRegion().back());
+
+  Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(0));
+  builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(1));
+
   builder.create<mlir::func::ReturnOp>(loc);
   return funcOp;
 }
@@ -168,28 +164,10 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
     if (auto reloaded = rootMapping.lookupOrNull(v))
       return nullptr;
-    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
-    Value alloc, reloaded;
-    if (isSupportedByFirAlloca(ty)) {
-      alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
-      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
-      reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
-    } else {
-      auto one = allocaBuilder.create<LLVM::ConstantOp>(
-          loc, allocaBuilder.getI32Type(), 1);
-      alloc =
-          allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
-      Value toStore = singleBuilder
-                          .create<UnrealizedConversionCastOp>(
-                              loc, llvmPtrTy, singleMapping.lookup(v))
-                          .getResult(0);
-      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
-      reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
-      reloaded =
-          parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
-              .getResult(0);
-    }
+    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
+    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
     rootMapping.map(v, reloaded);
     return alloc;
   };
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index b78cfd80e17acb..997bc8d79f9b3f 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -80,6 +80,8 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
 // CHECK:           return
 // CHECK:         }
 
@@ -130,12 +132,16 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
 // CHECK:           return
 // CHECK:         }
 
 // CHECK-LABEL:   func.func private @_workshare_copy_i32(
 // CHECK-SAME:                                           %[[VAL_0:.*]]: !fir.ref<i32>,
 // CHECK-SAME:                                           %[[VAL_1:.*]]: !fir.ref<i32>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<i32>
 // CHECK:           return
 // CHECK:         }
 

>From 8bc89ab0ed9dd3607fe4bfb325418122febede09 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:59:34 +0900
Subject: [PATCH 19/23] Test the tmp storing logic

---
 .../Transforms/OpenMP/lower-workshare.mlir    |  2 -
 .../Transforms/OpenMP/lower-workshare3.mlir   | 74 +++++++++++++++++++
 2 files changed, 74 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare3.mlir

diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 997bc8d79f9b3f..063d3865065e01 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -36,7 +36,6 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   return
 }
 
-
 // -----
 
 // checks:
@@ -190,4 +189,3 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
-
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
new file mode 100644
index 00000000000000..84eded94503282
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -0,0 +1,74 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+
+// tests if the correct values are stored
+
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK-NOT: fir.alloca
+    omp.workshare {
+
+      %t1 = "test.test1"() : () -> i32
+      // CHECK: %[[T1:.*]] = "test.test1"
+      // CHECK: fir.store %[[T1]]
+      %t2 = "test.test2"() : () -> i32
+      // CHECK: %[[T2:.*]] = "test.test2"
+      // CHECK: fir.store %[[T2]]
+      %t3 = "test.test3"() : () -> i32
+      // CHECK: %[[T3:.*]] = "test.test3"
+      // CHECK-NOT: fir.store %[[T3]]
+      %t4 = "test.test4"() : () -> i32
+      // CHECK: %[[T4:.*]] = "test.test4"
+      // CHECK: fir.store %[[T4]]
+      %t5 = "test.test5"() : () -> i32
+      // CHECK: %[[T5:.*]] = "test.test5"
+      // CHECK: fir.store %[[T5]]
+      %t6 = "test.test6"() : () -> i32
+      // CHECK: %[[T6:.*]] = "test.test6"
+      // CHECK-NOT: fir.store %[[T6]]
+
+
+      "test.test1"(%t1) : (i32) -> ()
+      "test.test1"(%t2) : (i32) -> ()
+      "test.test1"(%t3) : (i32) -> ()
+
+      %true = arith.constant true
+      fir.if %true {
+        "test.test2"(%t3) : (i32) -> ()
+      }
+
+      %c1_i32 = arith.constant 1 : i32
+
+      %t5_pure_use = arith.addi %t5, %c1_i32 : i32
+
+      %t6_mem_effect_use = "test.test8"(%t6) : (i32) -> i32
+      // CHECK: %[[T6_USE:.*]] = "test.test8"
+      // CHECK: fir.store %[[T6_USE]]
+
+      %c42 = arith.constant 42 : index
+      %c1 = arith.constant 1 : index
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          "test.test10"(%t1) : (i32) -> ()
+          "test.test10"(%t5_pure_use) : (i32) -> ()
+          "test.test10"(%t6_mem_effect_use) : (i32) -> ()
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+
+      "test.test10"(%t2) : (i32) -> ()
+      fir.if %true {
+        "test.test10"(%t4) : (i32) -> ()
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}

>From a2fc6fa60690a38ed534443db0c0c8b8b8d4799a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 20:24:36 +0900
Subject: [PATCH 20/23] Clean up trivially dead ops

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 32 ++++-------
 .../Transforms/OpenMP/lower-workshare3.mlir   |  2 +-
 .../Transforms/OpenMP/lower-workshare4.mlir   | 55 +++++++++++++++++++
 3 files changed, 68 insertions(+), 21 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare4.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 20f45296a8159a..a147db2cb5d59a 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -152,6 +152,14 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
   return false;
 }
 
+/// We clone pure operations in both the parallel and single blocks. This
+/// function cleans them up if they end up with no uses.
+static void cleanupBlock(Block *block) {
+  for (Operation &op : llvm::make_early_inc_range(*block))
+    if (isOpTriviallyDead(&op))
+      op.erase();
+}
+
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
   OpBuilder rootBuilder(sourceRegion.getContext());
@@ -258,13 +266,8 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         singleOperands.copyprivateVars =
             moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                          singleBuilder, parallelBuilder);
+        cleanupBlock(singleBlock);
         for (auto var : singleOperands.copyprivateVars) {
-          Type ty;
-          if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
-            ty = firAlloca.getAllocatedType();
-          } else {
-            ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
-          }
           mlir::func::FuncOp funcOp =
               createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
           singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
@@ -302,6 +305,9 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 
     rootBuilder.clone(*block.getTerminator(), rootMapping);
   }
+
+  for (Block &targetBlock : targetRegion)
+    cleanupBlock(&targetBlock);
 }
 
 /// Lowers workshare to a sequence of single-thread regions and parallel loops
@@ -372,20 +378,6 @@ class LowerWorksharePass
 
       lowerWorkshare(wsOp);
     });
-
-    // Do folding
-    for (Operation *isolatedParent : parents) {
-      RewritePatternSet patterns(&getContext());
-      GreedyRewriteConfig config;
-      // prevent the pattern driver form merging blocks
-      config.enableRegionSimplification =
-          mlir::GreedySimplifyRegionLevel::Disabled;
-      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
-                                              std::move(patterns), config))) {
-        emitError(isolatedParent->getLoc(), "error in lower workshare\n");
-        signalPassFailure();
-      }
-    }
   }
 };
 } // namespace
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index 84eded94503282..aee95a464a31bd 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -3,7 +3,7 @@
 
 // tests if the correct values are stored
 
-func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+func.func @wsfunc() {
   omp.parallel {
   // CHECK: fir.alloca
   // CHECK: fir.alloca
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
new file mode 100644
index 00000000000000..6cff0075b4fe50
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -0,0 +1,55 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+      %t1 = "test.test1"() : () -> i32
+
+      %c1 = arith.constant 1 : index
+      %c42 = arith.constant 42 : index
+
+      %c2 = arith.constant 2 : index
+      "test.test3"(%c2) : (index) -> ()
+
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          "test.test2"() : () -> ()
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @wsfunc() {
+// CHECK:           %[[VAL_0:.*]] = fir.alloca i32
+// CHECK:           omp.parallel {
+// CHECK:             %[[VAL_1:.*]] = arith.constant true
+// CHECK:             fir.if %[[VAL_1]] {
+// CHECK:               omp.single {
+// CHECK:                 %[[VAL_2:.*]] = "test.test1"() : () -> i32
+// CHECK:                 %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:                 "test.test3"(%[[VAL_3]]) : (index) -> ()
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_5:.*]] = arith.constant 42 : index
+// CHECK:               omp.wsloop nowait {
+// CHECK:                 omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) {
+// CHECK:                   "test.test2"() : () -> ()
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.barrier
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+

>From 1658a875460572bfe449840966b518e423594ac3 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 21:55:14 +0900
Subject: [PATCH 21/23] Only handle single-block regions for now

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  80 +++++----
 .../Transforms/OpenMP/lower-workshare.mlir    | 154 +++++++++---------
 .../Transforms/OpenMP/lower-workshare4.mlir   |  31 ++--
 3 files changed, 143 insertions(+), 122 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index a147db2cb5d59a..5998489c13d382 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -13,12 +13,14 @@
 #include <flang/Optimizer/Dialect/FIRType.h>
 #include <flang/Optimizer/HLFIR/HLFIROps.h>
 #include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
 #include <llvm/ADT/STLExtras.h>
 #include <llvm/ADT/SmallVectorExtras.h>
 #include <llvm/ADT/iterator_range.h>
 #include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
 #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
 #include <mlir/IR/BuiltinOps.h>
@@ -161,7 +163,8 @@ static void cleanupBlock(Block *block) {
 }
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
-                              IRMapping &rootMapping, Location loc) {
+                              IRMapping &rootMapping, Location loc,
+                              mlir::DominanceInfo &di) {
   OpBuilder rootBuilder(sourceRegion.getContext());
   ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
   OpBuilder copyFuncBuilder(m.getBodyRegion());
@@ -214,14 +217,19 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     return copyPrivate;
   };
 
-  // TODO Need to handle these (clone them) in dominator tree order
   for (Block &block : sourceRegion) {
-    rootBuilder.createBlock(
+    Block *targetBlock = rootBuilder.createBlock(
         &targetRegion, {}, block.getArgumentTypes(),
         llvm::map_to_vector(block.getArguments(),
                             [](BlockArgument arg) { return arg.getLoc(); }));
-    Operation *terminator = block.getTerminator();
+    rootMapping.map(&block, targetBlock);
+    rootMapping.map(block.getArguments(), targetBlock->getArguments());
+  }
 
+  auto handleOneBlock = [&](Block &block) {
+    Block &targetBlock = *rootMapping.lookup(&block);
+    rootBuilder.setInsertionPointToStart(&targetBlock);
+    Operation *terminator = block.getTerminator();
     SmallVector<std::variant<SingleRegion, Operation *>> regions;
 
     auto it = block.begin();
@@ -298,12 +306,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
           for (auto [region, clonedRegion] :
                llvm::zip(op->getRegions(), cloned->getRegions()))
-            parallelizeRegion(region, clonedRegion, rootMapping, loc);
+            parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
         }
       }
     }
 
     rootBuilder.clone(*block.getTerminator(), rootMapping);
+  };
+
+  if (sourceRegion.hasOneBlock()) {
+    handleOneBlock(sourceRegion.front());
+  } else {
+    auto &domTree = di.getDomTree(&sourceRegion);
+    for (auto node : llvm::breadth_first(domTree.getRootNode())) {
+      handleOneBlock(*node->getBlock());
+    }
   }
 
   for (Block &targetBlock : targetRegion)
@@ -336,47 +353,46 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 ///
 /// Note that we allocate temporary memory for values in omp.single's which need
 /// to be accessed in all threads in the closest omp.parallel
-void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
   Location loc = wsOp->getLoc();
   IRMapping rootMapping;
 
   OpBuilder rootBuilder(wsOp);
 
-  // TODO We need something like an scf;execute here, but that is not registered
-  // so using fir.if for now but it looks like it does not support multiple
-  // blocks so it doesnt work for multi block case...
-  auto ifOp = rootBuilder.create<fir::IfOp>(
-      loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
-  ifOp.getThenRegion().front().erase();
-
-  parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
-
-  Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
-  assert(isa<omp::TerminatorOp>(terminatorOp));
-  OpBuilder termBuilder(terminatorOp);
-
+  // TODO We need something like an scf.execute here, but that is not registered
+  // so using omp.workshare as a placeholder. We need this op as our
+  // parallelizeRegion works on regions and not blocks.
+  omp::WorkshareOp newOp =
+      rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
   if (!wsOp.getNowait())
-    termBuilder.create<omp::BarrierOp>(loc);
-
-  termBuilder.create<fir::ResultOp>(loc, ValueRange());
-
-  terminatorOp->erase();
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc, di);
+
+  if (wsOp.getRegion().getBlocks().size() != 1)
+    return failure();
+
+  // Inline the contents of the placeholder workshare op into its parent block.
+  Block *theBlock = &newOp.getRegion().front();
+  Operation *term = theBlock->getTerminator();
+  Block *parentBlock = wsOp->getBlock();
+  parentBlock->getOperations().splice(newOp->getIterator(),
+                                      theBlock->getOperations());
+  assert(term->getNumOperands() == 0);
+  term->erase();
+  newOp->erase();
   wsOp->erase();
-
-  return;
+  return success();
 }
 
 class LowerWorksharePass
     : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
 public:
   void runOnOperation() override {
-    SmallPtrSet<Operation *, 8> parents;
+    mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
     getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
-      Operation *isolatedParent =
-          wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
-      parents.insert(isolatedParent);
-
-      lowerWorkshare(wsOp);
+      if (failed(lowerWorkshare(wsOp, di)))
+        signalPassFailure();
     });
   }
 };
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 063d3865065e01..b31e951223d56f 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -86,43 +86,46 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant true
 // CHECK:           omp.parallel {
-// CHECK:             fir.if %[[VAL_4]] {
-// CHECK:               %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:               omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:                 %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:                 %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:                 %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:                 fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:                 %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               omp.wsloop {
-// CHECK:                 omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                   %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                   %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
-// CHECK:                   %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
-// CHECK:                   %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                   hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
-// CHECK:                   omp.yield
-// CHECK:                 }
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               omp.single nowait {
-// CHECK:                 hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:                 fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
-// CHECK:                 omp.terminator
+// CHECK:             %[[VAL_1:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK:               %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_5]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_7:.*]] = arith.constant 42 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : i32
+// CHECK:             %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_9]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_11]](%[[VAL_9]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_13:.*]] = arith.constant true
+// CHECK:             %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_14]]) to (%[[VAL_7]]) inclusive step (%[[VAL_14]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_8]] : i32
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                 omp.yield
 // CHECK:               }
-// CHECK:               omp.barrier
+// CHECK:               omp.terminator
 // CHECK:             }
+// CHECK:             omp.single nowait {
+// CHECK:               %[[VAL_20:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               %[[VAL_21:.*]] = fir.insert_value %[[VAL_20]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               hlfir.assign %[[VAL_12]]#0 to %[[VAL_10]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_22:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             %[[VAL_23:.*]] = fir.insert_value %[[VAL_22]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             omp.barrier
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
@@ -146,46 +149,51 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_4:.*]] = arith.constant true
-// CHECK:           fir.if %[[VAL_4]] {
-// CHECK:             %[[VAL_5:.*]] = fir.alloca i32
-// CHECK:             %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:             omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:               fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
-// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               omp.terminator
-// CHECK:             }
-// CHECK:             %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:             %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
-// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
-// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
-// CHECK:                 omp.yield
-// CHECK:               }
-// CHECK:               omp.terminator
-// CHECK:             }
-// CHECK:             omp.single nowait {
-// CHECK:               "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
-// CHECK:               hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
-// CHECK:               omp.terminator
+// CHECK:           %[[VAL_1:.*]] = fir.alloca i32
+// CHECK:           %[[VAL_2:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_2]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_1]] : !fir.ref<i32>
+// CHECK:             %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK:             %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_7:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:             fir.store %[[VAL_7]] to %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_7]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           %[[VAL_9:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_10:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_11:.*]] = fir.shape %[[VAL_10]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_15:.*]] = arith.constant true
+// CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:           omp.wsloop {
+// CHECK:             omp.loop_nest (%[[VAL_17:.*]]) : index = (%[[VAL_16]]) to (%[[VAL_10]]) inclusive step (%[[VAL_16]]) {
+// CHECK:               %[[VAL_18:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_17]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:               %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
+// CHECK:               %[[VAL_20:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
+// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_19]], %[[VAL_20]] : i32
+// CHECK:               %[[VAL_22:.*]] = arith.subi %[[VAL_21]], %[[VAL_9]] : i32
+// CHECK:               %[[VAL_23:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_17]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:               hlfir.assign %[[VAL_22]] to %[[VAL_23]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.yield
 // CHECK:             }
-// CHECK:             omp.barrier
+// CHECK:             omp.terminator
 // CHECK:           }
+// CHECK:           omp.single nowait {
+// CHECK:             %[[VAL_24:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             %[[VAL_25:.*]] = fir.insert_value %[[VAL_24]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             "test.test1"(%[[VAL_1]]) : (!fir.ref<i32>) -> ()
+// CHECK:             hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:             fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           %[[VAL_26:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:           %[[VAL_27:.*]] = fir.insert_value %[[VAL_26]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:           omp.barrier
 // CHECK:           return
 // CHECK:         }
+
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index 6cff0075b4fe50..d695a1c354517b 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -29,25 +29,22 @@ func.func @wsfunc() {
 // CHECK-LABEL:   func.func @wsfunc() {
 // CHECK:           %[[VAL_0:.*]] = fir.alloca i32
 // CHECK:           omp.parallel {
-// CHECK:             %[[VAL_1:.*]] = arith.constant true
-// CHECK:             fir.if %[[VAL_1]] {
-// CHECK:               omp.single {
-// CHECK:                 %[[VAL_2:.*]] = "test.test1"() : () -> i32
-// CHECK:                 %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK:                 "test.test3"(%[[VAL_3]]) : (index) -> ()
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:               %[[VAL_5:.*]] = arith.constant 42 : index
-// CHECK:               omp.wsloop nowait {
-// CHECK:                 omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) {
-// CHECK:                   "test.test2"() : () -> ()
-// CHECK:                   omp.yield
-// CHECK:                 }
-// CHECK:                 omp.terminator
+// CHECK:             omp.single {
+// CHECK:               %[[VAL_1:.*]] = "test.test1"() : () -> i32
+// CHECK:               %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:               "test.test3"(%[[VAL_2]]) : (index) -> ()
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK:             omp.wsloop nowait {
+// CHECK:               omp.loop_nest (%[[VAL_5:.*]]) : index = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_3]]) {
+// CHECK:                 "test.test2"() : () -> ()
+// CHECK:                 omp.yield
 // CHECK:               }
-// CHECK:               omp.barrier
+// CHECK:               omp.terminator
 // CHECK:             }
+// CHECK:             omp.barrier
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return

>From b5aadf53b32d186e5f31f561edf13c3b4f41005b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 6 Aug 2024 13:52:20 +0900
Subject: [PATCH 22/23] Fix tests for custom assembly for loop wrapper

---
 flang/test/Transforms/OpenMP/lower-workshare.mlir  | 8 ++++----
 flang/test/Transforms/OpenMP/lower-workshare3.mlir | 4 ++--
 flang/test/Transforms/OpenMP/lower-workshare4.mlir | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index b31e951223d56f..9347863dc4a609 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -13,7 +13,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
       %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
       %true = arith.constant true
       %c1 = arith.constant 1 : index
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
           %8 = fir.load %7 : !fir.ref<i32>
@@ -23,7 +23,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
       %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
       %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
       %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
@@ -52,7 +52,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
     %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
     %true = arith.constant true
     %c1 = arith.constant 1 : index
-    "omp.workshare_loop_wrapper"() ({
+    omp.workshare_loop_wrapper {
       omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
         %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
         %8 = fir.load %7 : !fir.ref<i32>
@@ -64,7 +64,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
         omp.yield
       }
       omp.terminator
-    }) : () -> ()
+    }
     %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
     %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
     %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index aee95a464a31bd..afb41d95e7198e 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -52,7 +52,7 @@ func.func @wsfunc() {
 
       %c42 = arith.constant 42 : index
       %c1 = arith.constant 1 : index
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           "test.test10"(%t1) : (i32) -> ()
           "test.test10"(%t5_pure_use) : (i32) -> ()
@@ -60,7 +60,7 @@ func.func @wsfunc() {
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
 
       "test.test10"(%t2) : (i32) -> ()
       fir.if %true {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index d695a1c354517b..0a70007a9e78dd 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -12,13 +12,13 @@ func.func @wsfunc() {
       %c2 = arith.constant 2 : index
       "test.test3"(%c2) : (index) -> ()
 
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           "test.test2"() : () -> ()
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
       omp.terminator
     }
     omp.terminator

>From 08584697956e12be8a83e93de795e4594aabd1e3 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 14:43:50 +0900
Subject: [PATCH 23/23] Only run the lower workshare pass if openmp is enabled

---
 flang/include/flang/Tools/CLOptions.inc      |  7 ++++---
 flang/include/flang/Tools/CrossToolHelpers.h |  1 +
 flang/lib/Frontend/FrontendActions.cpp       | 10 +++++++++-
 flang/tools/bbc/bbc.cpp                      |  5 ++++-
 flang/tools/tco/tco.cpp                      |  1 +
 5 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index a565effebfa92e..fbca7b6838ada7 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -328,7 +328,7 @@ inline void createDefaultFIROptimizerPassPipeline(
 /// \param optLevel - optimization level used for creating FIR optimization
 ///   passes pipeline
 inline void createHLFIRToFIRPassPipeline(
-    mlir::PassManager &pm, llvm::OptimizationLevel optLevel = defaultOptLevel) {
+    mlir::PassManager &pm, bool enableOpenMP, llvm::OptimizationLevel optLevel = defaultOptLevel) {
   if (optLevel.isOptimizingForSpeed()) {
     addCanonicalizerPassWithoutRegionSimplification(pm);
     addNestedPassToAllTopLevelOperations(
@@ -345,7 +345,8 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  pm.addPass(flangomp::createLowerWorkshare());
+  if (enableOpenMP)
+    pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
@@ -416,7 +417,7 @@ inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
 ///   passes pipeline
 inline void createMLIRToLLVMPassPipeline(mlir::PassManager &pm,
     MLIRToLLVMPassPipelineConfig &config, llvm::StringRef inputFilename = {}) {
-  fir::createHLFIRToFIRPassPipeline(pm, config.OptLevel);
+  fir::createHLFIRToFIRPassPipeline(pm, config.EnableOpenMP, config.OptLevel);
 
   // Add default optimizer pass pipeline.
   fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h
index 1d890fd8e1f6f9..dd258864ff7f22 100644
--- a/flang/include/flang/Tools/CrossToolHelpers.h
+++ b/flang/include/flang/Tools/CrossToolHelpers.h
@@ -123,6 +123,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
       false; ///< Set no-signed-zeros-fp-math attribute for functions.
   bool UnsafeFPMath = false; ///< Set unsafe-fp-math attribute for functions.
   bool NSWOnLoopVarInc = false; ///< Add nsw flag to loop variable increments.
+  bool EnableOpenMP = false; ///< Enable OpenMP lowering.
 };
 
 struct OffloadModuleOpts {
diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index 5c86bd947ce73f..db5c5649337528 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -711,7 +711,11 @@ void CodeGenAction::lowerHLFIRToFIR() {
   pm.enableVerifier(/*verifyPasses=*/true);
 
   // Create the pass pipeline
-  fir::createHLFIRToFIRPassPipeline(pm, level);
+  fir::createHLFIRToFIRPassPipeline(
+      pm,
+      ci.getInvocation().getFrontendOpts().features.IsEnabled(
+          Fortran::common::LanguageFeature::OpenMP),
+      level);
   (void)mlir::applyPassManagerCLOptions(pm);
 
   if (!mlir::succeeded(pm.run(*mlirModule))) {
@@ -824,6 +828,10 @@ void CodeGenAction::generateLLVMIR() {
     config.VScaleMax = vsr->second;
   }
 
+  if (ci.getInvocation().getFrontendOpts().features.IsEnabled(
+          Fortran::common::LanguageFeature::OpenMP))
+    config.EnableOpenMP = true;
+
   if (ci.getInvocation().getLoweringOpts().getNSWOnLoopVarInc())
     config.NSWOnLoopVarInc = true;
 
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index 07eef065daf6f4..681b23883df44a 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -429,7 +429,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
 
     if (emitFIR && useHLFIR) {
       // lower HLFIR to FIR
-      fir::createHLFIRToFIRPassPipeline(pm, llvm::OptimizationLevel::O2);
+      fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP,
+                                        llvm::OptimizationLevel::O2);
       if (mlir::failed(pm.run(mlirModule))) {
         llvm::errs() << "FATAL: lowering from HLFIR to FIR failed";
         return mlir::failure();
@@ -444,6 +445,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
 
     // Add O2 optimizer pass pipeline.
     MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+    if (enableOpenMP)
+      config.EnableOpenMP = true;
     config.NSWOnLoopVarInc = setNSW;
     fir::registerDefaultInlinerPass(config);
     fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/tools/tco/tco.cpp b/flang/tools/tco/tco.cpp
index a8c64333109aeb..06892cdc3f6a80 100644
--- a/flang/tools/tco/tco.cpp
+++ b/flang/tools/tco/tco.cpp
@@ -138,6 +138,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
       return mlir::failure();
   } else {
     MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+    config.EnableOpenMP = true;  // assume the input contains OpenMP
     config.AliasAnalysis = true; // enabled when optimizing for speed
     if (codeGenLLVM) {
       // Run only CodeGen passes.



More information about the llvm-branch-commits mailing list