[llvm-branch-commits] [flang] [mlir] [flang] Lower omp.workshare to other omp constructs (PR #101446)

Ivan R. Ivanov via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Aug 5 21:53:13 PDT 2024


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/101446

>From cc551ca6efe69d0f3e8b719766e6190b59623b42 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 6 Aug 2024 13:41:22 +0900
Subject: [PATCH 01/23] Add assembly format for wrapper and add test

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  2 +-
 mlir/test/Dialect/OpenMP/ops.mlir             | 61 +++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 863cd81923c87..0f29e911cb2f2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -313,10 +313,10 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare_loop_wrapper", traits = [
     RecursiveMemoryEffects, SingleBlock
   ], singleRegion = true> {
   let summary = "contains loop nests to be parallelized by workshare";
-
   let builders = [
     OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
   ];
+  let assemblyFormat = "$region attr-dict";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d2924998f41b8..981e3fbb0306b 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2789,3 +2789,64 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
 
   return
 }
+
+// CHECK-LABEL: func @omp_workshare
+func.func @omp_workshare() {
+  // CHECK: omp.workshare {
+  omp.workshare {
+    "test.payload"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_nowait
+func.func @omp_workshare_nowait() {
+  // CHECK: omp.workshare nowait {
+  omp.workshare nowait {
+    "test.payload"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_multiple_blocks
+func.func @omp_workshare_multiple_blocks() {
+  // CHECK: omp.workshare {
+  omp.workshare {
+    cf.br ^bb2
+    ^bb2:
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_loop_wrapper
+func.func @omp_workshare_loop_wrapper(%idx : index) {
+  // CHECK-NEXT: omp.workshare_loop_wrapper
+  omp.workshare_loop_wrapper {
+    // CHECK-NEXT: omp.loop_nest
+    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+      omp.yield
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_loop_wrapper_attrs
+func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
+  // CHECK-NEXT: omp.workshare_loop_wrapper {
+  omp.workshare_loop_wrapper {
+    // CHECK-NEXT: omp.loop_nest
+    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+      omp.yield
+    }
+    omp.terminator
+  // CHECK: } {attr_in_dict}
+  } {attr_in_dict}
+  return
+}

>From aea7da1492e9ad122b3b17f07b8f91b7c6eac777 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:11:47 +0900
Subject: [PATCH 02/23] [flang][omp] Emit omp.workshare in frontend

---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 30 ++++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 2b1839b5270d4..f7bc565ea8cbc 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1270,6 +1270,15 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter,
       loc, llvm::omp::Directive::OMPD_taskwait);
 }
 
+static void genWorkshareClauses(lower::AbstractConverter &converter,
+                                semantics::SemanticsContext &semaCtx,
+                                lower::StatementContext &stmtCtx,
+                                const List<Clause> &clauses, mlir::Location loc,
+                                mlir::omp::WorkshareOperands &clauseOps) {
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processNowait(clauseOps);
+}
+
 static void genTeamsClauses(lower::AbstractConverter &converter,
                             semantics::SemanticsContext &semaCtx,
                             lower::StatementContext &stmtCtx,
@@ -1890,6 +1899,22 @@ genTaskyieldOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   return converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(loc);
 }
 
+static mlir::omp::WorkshareOp
+genWorkshareOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+           semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+           mlir::Location loc, const ConstructQueue &queue,
+           ConstructQueue::iterator item) {
+  lower::StatementContext stmtCtx;
+  mlir::omp::WorkshareOperands clauseOps;
+  genWorkshareClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
+
+  return genOpWithBody<mlir::omp::WorkshareOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workshare)
+          .setClauses(&item->clauses),
+      queue, item, clauseOps);
+}
+
 static mlir::omp::TeamsOp
 genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
@@ -2249,10 +2274,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                   llvm::omp::getOpenMPDirectiveName(dir) + ")");
   // case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
-    // FIXME: Workshare is not a commonly used OpenMP construct, an
-    // implementation for this feature will come later. For the codes
-    // that use this construct, add a single construct for now.
-    genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genWorkshareOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
 
   // Composite constructs
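
Note (not part of the patch): with this change a Fortran `!$omp workshare` nested in `!$omp parallel` is expected to lower to an omp.workshare region instead of omp.single. A minimal sketch of the resulting IR shape, consistent with the test update in the next patch:

  omp.parallel {
    omp.workshare {
      // the workshared Fortran array assignments are lowered here
      omp.terminator
    }
    omp.terminator
  }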

>From 9046df2cebff7e06e803a7db6df506dc67c6edce Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:02:37 +0900
Subject: [PATCH 03/23] Fix lower test for workshare

---
 flang/test/Lower/OpenMP/workshare.f90 | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flang/test/Lower/OpenMP/workshare.f90 b/flang/test/Lower/OpenMP/workshare.f90
index 1e11677a15e1f..8e771952f5b6d 100644
--- a/flang/test/Lower/OpenMP/workshare.f90
+++ b/flang/test/Lower/OpenMP/workshare.f90
@@ -6,7 +6,7 @@ subroutine sb1(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
   !$omp parallel
-!CHECK: omp.single  {
+!CHECK: omp.workshare {
   !$omp workshare
     arr = 0
   !$omp end workshare
@@ -20,7 +20,7 @@ subroutine sb2(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
   !$omp parallel
-!CHECK: omp.single nowait {
+!CHECK: omp.workshare nowait {
   !$omp workshare
     arr = 0
   !$omp end workshare nowait
@@ -33,7 +33,7 @@ subroutine sb2(arr)
 subroutine sb3(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
-!CHECK: omp.single  {
+!CHECK: omp.workshare  {
   !$omp parallel workshare
     arr = 0
   !$omp end parallel workshare

>From cca8588b2b2d1c4275b9eaf22c7bd017942f06d9 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:12:34 +0900
Subject: [PATCH 04/23] [flang] Introduce ws loop nest generation for HLFIR
 lowering

---
 .../flang/Optimizer/Builder/HLFIRTools.h      | 12 +++--
 flang/lib/Lower/ConvertCall.cpp               |  2 +-
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp |  4 +-
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    | 52 ++++++++++++++-----
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       |  3 +-
 .../LowerHLFIROrderedAssignments.cpp          | 30 +++++------
 .../Transforms/OptimizedBufferization.cpp     |  6 +--
 7 files changed, 69 insertions(+), 40 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 6b41025eea078..14e42c6f358e4 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -357,8 +357,8 @@ hlfir::ElementalOp genElementalOp(
 
 /// Structure to describe a loop nest.
 struct LoopNest {
-  fir::DoLoopOp outerLoop;
-  fir::DoLoopOp innerLoop;
+  mlir::Operation *outerOp;
+  mlir::Block *body;
   llvm::SmallVector<mlir::Value> oneBasedIndices;
 };
 
@@ -366,11 +366,13 @@ struct LoopNest {
 /// \p isUnordered specifies whether the loops in the loop nest
 /// are unordered.
 LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
-                     mlir::ValueRange extents, bool isUnordered = false);
+                     mlir::ValueRange extents, bool isUnordered = false,
+                     bool emitWsLoop = false);
 inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
-                            mlir::Value shape, bool isUnordered = false) {
+                            mlir::Value shape, bool isUnordered = false,
+                            bool emitWsLoop = false) {
   return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
-                     isUnordered);
+                     isUnordered, emitWsLoop);
 }
 
 /// Inline the body of an hlfir.elemental at the current insertion point
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index fd873f55dd844..0689d6e033dd9 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -2128,7 +2128,7 @@ class ElementalCallBuilder {
           hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
       mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
       auto insPt = builder.saveInsertionPoint();
-      builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+      builder.setInsertionPointToStart(loopNest.body);
       callContext.stmtCtx.pushScope();
       for (auto &preparedActual : loweredActuals)
         if (preparedActual)
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c3c1f363033c2..72a90dd0d6f29 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
   // know this won't miss any opportuinties for clever elemental inlining
   hlfir::LoopNest nest = hlfir::genLoopNest(
       loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
-  builder.setInsertionPointToStart(nest.innerLoop.getBody());
+  builder.setInsertionPointToStart(nest.body);
   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
@@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
       builder, loc, redId, refTy, lhsEle, rhsEle);
   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
 
-  builder.setInsertionPointAfter(nest.outerLoop);
+  builder.setInsertionPointAfter(nest.outerOp);
   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
 }
 
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 8d0ae2f195178..cd07cb741eb4b 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <optional>
 
 // Return explicit extents. If the base is a fir.box, this won't read it to
@@ -855,26 +856,51 @@ mlir::Value hlfir::inlineElementalOp(
 
 hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
-                                   mlir::ValueRange extents, bool isUnordered) {
+                                   mlir::ValueRange extents, bool isUnordered,
+                                   bool emitWsLoop) {
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
-  auto insPt = builder.saveInsertionPoint();
+  mlir::OpBuilder::InsertionGuard guard(builder);
   loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
   // Build loop nest from column to row.
   auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
   mlir::Type indexType = builder.getIndexType();
-  unsigned dim = extents.size() - 1;
-  for (auto extent : llvm::reverse(extents)) {
-    auto ub = builder.createConvert(loc, indexType, extent);
-    loopNest.innerLoop =
-        builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
-    builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
-    // Reverse the indices so they are in column-major order.
-    loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar();
-    if (!loopNest.outerLoop)
-      loopNest.outerLoop = loopNest.innerLoop;
+  if (emitWsLoop) {
+    auto wsloop = builder.create<mlir::omp::WsloopOp>(
+        loc, mlir::ArrayRef<mlir::NamedAttribute>());
+    loopNest.outerOp = wsloop;
+    builder.createBlock(&wsloop.getRegion());
+    mlir::omp::LoopNestOperands lnops;
+    lnops.loopInclusive = builder.getUnitAttr();
+    for (auto extent : llvm::reverse(extents)) {
+      lnops.loopLowerBounds.push_back(one);
+      lnops.loopUpperBounds.push_back(extent);
+      lnops.loopSteps.push_back(one);
+    }
+    auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
+    builder.create<mlir::omp::TerminatorOp>(loc);
+    mlir::Block *block = builder.createBlock(&lnOp.getRegion());
+    for (auto extent : llvm::reverse(extents))
+      block->addArgument(extent.getType(), extent.getLoc());
+    loopNest.body = block;
+    builder.create<mlir::omp::YieldOp>(loc);
+    for (unsigned dim = 0; dim < extents.size(); dim++)
+      loopNest.oneBasedIndices[extents.size() - dim - 1] =
+          lnOp.getRegion().front().getArgument(dim);
+  } else {
+    unsigned dim = extents.size() - 1;
+    for (auto extent : llvm::reverse(extents)) {
+      auto ub = builder.createConvert(loc, indexType, extent);
+      auto doLoop =
+          builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
+      loopNest.body = doLoop.getBody();
+      builder.setInsertionPointToStart(loopNest.body);
+      // Reverse the indices so they are in column-major order.
+      loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
+      if (!loopNest.outerOp)
+        loopNest.outerOp = doLoop;
+    }
   }
-  builder.restoreInsertionPoint(insPt);
   return loopNest;
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index a70a6b388c4b1..b608677c52631 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace hlfir {
@@ -793,7 +794,7 @@ struct ElementalOpConversion
     hlfir::LoopNest loopNest =
         hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
     auto insPt = builder.saveInsertionPoint();
-    builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+    builder.setInsertionPointToStart(loopNest.body);
     auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                           loopNest.oneBasedIndices);
     hlfir::Entity elementValue(yield.getElementValue());
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index 85dd517cb5791..645abf65d10a3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
       // if the LHS is not).
       mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
       elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
-      builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
+      builder.setInsertionPointToStart(elementalLoopNest->body);
       lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
                                       elementalLoopNest->oneBasedIndices);
       rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
@@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
     for (auto &cleanupConversion : argConversionCleanups)
       cleanupConversion();
     if (elementalLoopNest)
-      builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
+      builder.setInsertionPointAfter(elementalLoopNest->outerOp);
   } else {
     // TODO: preserve allocatable assignment aspects for forall once
     // they are conveyed in hlfir.region_assign.
@@ -493,7 +493,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
   generateCleanupIfAny(loweredLhs.elementalCleanup);
   if (loweredLhs.vectorSubscriptLoopNest)
     builder.setInsertionPointAfter(
-        loweredLhs.vectorSubscriptLoopNest->outerLoop);
+        loweredLhs.vectorSubscriptLoopNest->outerOp);
   generateCleanupIfAny(oldRhsYield);
   generateCleanupIfAny(loweredLhs.nonElementalCleanup);
 }
@@ -518,8 +518,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
       hlfir::Entity savedMask{maybeSaved->first};
       mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
       whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
-      constructStack.push_back(whereLoopNest->outerLoop.getOperation());
-      builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+      constructStack.push_back(whereLoopNest->outerOp);
+      builder.setInsertionPointToStart(whereLoopNest->body);
       mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
                                             whereLoopNest->oneBasedIndices);
       generateMaskIfOp(cdt);
@@ -527,7 +527,7 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
         // If this is the same run as the one that saved the value, the clean-up
         // was left-over to be done now.
         auto insertionPoint = builder.saveInsertionPoint();
-        builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+        builder.setInsertionPointAfter(whereLoopNest->outerOp);
         generateCleanupIfAny(maybeSaved->second);
         builder.restoreInsertionPoint(insertionPoint);
       }
@@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
     mask.generateNoneElementalPart(builder, mapper);
     mlir::Value shape = mask.generateShape(builder, mapper);
     whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
-    constructStack.push_back(whereLoopNest->outerLoop.getOperation());
-    builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+    constructStack.push_back(whereLoopNest->outerOp);
+    builder.setInsertionPointToStart(whereLoopNest->body);
     mlir::Value cdt = generateMaskedEntity(mask);
     generateMaskIfOp(cdt);
     return;
@@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
       loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
           loc, builder, loweredLhs.vectorSubscriptShape.value());
       builder.setInsertionPointToStart(
-          loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+          loweredLhs.vectorSubscriptLoopNest->body);
     }
     loweredLhs.lhs = temp->second.fetch(loc, builder);
     return loweredLhs;
@@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
         hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                            !elementalAddrLhs.isOrdered());
     builder.setInsertionPointToStart(
-        loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+        loweredLhs.vectorSubscriptLoopNest->body);
     mapper.map(elementalAddrLhs.getIndices(),
                loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
     for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
@@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
   if (!maskedExpr.noneElementalPartWasGenerated) {
     // Generate none elemental part before the where loops (but inside the
     // current forall loops if any).
-    builder.setInsertionPoint(whereLoopNest->outerLoop);
+    builder.setInsertionPoint(whereLoopNest->outerOp);
     maskedExpr.generateNoneElementalPart(builder, mapper);
   }
   // Generate the none elemental part cleanup after the where loops.
-  builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+  builder.setInsertionPointAfter(whereLoopNest->outerOp);
   maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
   // Generate the value of the current element for the masked expression
   // at the current insertion point (inside the where loops, and any fir.if
@@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
   LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
   fir::factory::TemporaryStorage *temp = nullptr;
   if (loweredLhs.vectorSubscriptLoopNest)
-    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
   if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
     // Vector subscripted entity for which the shape must also be saved on top
     // of the element addresses (e.g. the shape may change in each forall
@@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
     // subscripted LHS.
     auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
     auto insertionPoint = builder.saveInsertionPoint();
-    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
     vectorTmp.pushShape(loc, builder, shape);
     builder.restoreInsertionPoint(insertionPoint);
   } else {
@@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
   if (loweredLhs.vectorSubscriptLoopNest) {
     constructStack.pop_back();
     builder.setInsertionPointAfter(
-        loweredLhs.vectorSubscriptLoopNest->outerLoop);
+        loweredLhs.vectorSubscriptLoopNest->outerOp);
   }
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c5b809514c54c..c4aed6b79df92 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
   // hlfir.elemental region inside the inner loop
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                         loopNest.oneBasedIndices);
   hlfir::Entity elementValue{yield.getElementValue()};
@@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto arrayElement =
       hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
   builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
@@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto rhsArrayElement =
       hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
   rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
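
For illustration (a sketch, not verbatim output of the patch): when emitWsLoop is set, the new genLoopNest path builds an omp.wsloop wrapping an omp.loop_nest instead of a nest of fir.do_loop ops. %one, %extent and %i below are placeholder SSA names:

  omp.wsloop {
    omp.loop_nest (%i) : index = (%one) to (%extent) inclusive step (%one) {
      // loopNest.body: callers insert the element-wise body here
      omp.yield
    }
    omp.terminator
  }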

>From a47ea24f2f30be0dcaf7997287a5a3ba4ca83161 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:08:34 +0900
Subject: [PATCH 05/23] Emit loop nests in a custom wrapper

---
 flang/include/flang/Optimizer/Builder/HLFIRTools.h |  6 +++---
 flang/lib/Optimizer/Builder/HLFIRTools.cpp         | 11 +++++------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 14e42c6f358e4..6987471957218 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -367,12 +367,12 @@ struct LoopNest {
 /// are unordered.
 LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                      mlir::ValueRange extents, bool isUnordered = false,
-                     bool emitWsLoop = false);
+                     bool emitWorkshareLoop = false);
 inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                             mlir::Value shape, bool isUnordered = false,
-                            bool emitWsLoop = false) {
+                            bool emitWorkshareLoop = false) {
   return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
-                     isUnordered, emitWsLoop);
+                     isUnordered, emitWorkshareLoop);
 }
 
 /// Inline the body of an hlfir.elemental at the current insertion point
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index cd07cb741eb4b..91b1b3d774a01 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -857,7 +857,7 @@ mlir::Value hlfir::inlineElementalOp(
 hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    mlir::ValueRange extents, bool isUnordered,
-                                   bool emitWsLoop) {
+                                   bool emitWorkshareLoop) {
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
   mlir::OpBuilder::InsertionGuard guard(builder);
@@ -865,11 +865,10 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
   // Build loop nest from column to row.
   auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
   mlir::Type indexType = builder.getIndexType();
-  if (emitWsLoop) {
-    auto wsloop = builder.create<mlir::omp::WsloopOp>(
-        loc, mlir::ArrayRef<mlir::NamedAttribute>());
-    loopNest.outerOp = wsloop;
-    builder.createBlock(&wsloop.getRegion());
+  if (emitWorkshareLoop) {
+    auto wslw = builder.create<mlir::omp::WorkshareLoopWrapperOp>(loc);
+    loopNest.outerOp = wslw;
+    builder.createBlock(&wslw.getRegion());
     mlir::omp::LoopNestOperands lnops;
     lnops.loopInclusive = builder.getUnitAttr();
     for (auto extent : llvm::reverse(extents)) {

>From f1b9e862b4af5c236983ce44f7e8bca5b1ea376f Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 22:05:47 +0900
Subject: [PATCH 06/23] Only emit unordered loops as omp loops

---
 flang/lib/Optimizer/Builder/HLFIRTools.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 91b1b3d774a01..333331378841e 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -858,6 +858,7 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    mlir::ValueRange extents, bool isUnordered,
                                    bool emitWorkshareLoop) {
+  emitWorkshareLoop = emitWorkshareLoop && isUnordered;
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
   mlir::OpBuilder::InsertionGuard guard(builder);

>From 33a5f19d31b6cda101fce7029de598b0c9889853 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 22:06:55 +0900
Subject: [PATCH 07/23] [flang] Lower omp.workshare to other omp constructs

---
 flang/include/flang/Optimizer/CMakeLists.txt  |   1 +
 .../flang/Optimizer/OpenMP/CMakeLists.txt     |   4 +
 flang/include/flang/Optimizer/OpenMP/Passes.h |  30 ++
 .../include/flang/Optimizer/OpenMP/Passes.td  |  18 ++
 flang/include/flang/Tools/CLOptions.inc       |   2 +
 flang/lib/Frontend/CMakeLists.txt             |   1 +
 flang/lib/Optimizer/CMakeLists.txt            |   1 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |  26 ++
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 259 ++++++++++++++++++
 flang/test/HLFIR/bufferize-workshare.fir      |  58 ++++
 .../Transforms/OpenMP/lower-workshare.mlir    |  81 ++++++
 .../Transforms/OpenMP/lower-workshare5.mlir   |  42 +++
 flang/tools/bbc/CMakeLists.txt                |   1 +
 flang/tools/fir-opt/CMakeLists.txt            |   1 +
 flang/tools/fir-opt/fir-opt.cpp               |   2 +
 flang/tools/tco/CMakeLists.txt                |   1 +
 16 files changed, 528 insertions(+)
 create mode 100644 flang/include/flang/Optimizer/OpenMP/CMakeLists.txt
 create mode 100644 flang/include/flang/Optimizer/OpenMP/Passes.h
 create mode 100644 flang/include/flang/Optimizer/OpenMP/Passes.td
 create mode 100644 flang/lib/Optimizer/OpenMP/CMakeLists.txt
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
 create mode 100644 flang/test/HLFIR/bufferize-workshare.fir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare5.mlir

diff --git a/flang/include/flang/Optimizer/CMakeLists.txt b/flang/include/flang/Optimizer/CMakeLists.txt
index 89e43a9ee8d62..3336ac935e101 100644
--- a/flang/include/flang/Optimizer/CMakeLists.txt
+++ b/flang/include/flang/Optimizer/CMakeLists.txt
@@ -2,3 +2,4 @@ add_subdirectory(CodeGen)
 add_subdirectory(Dialect)
 add_subdirectory(HLFIR)
 add_subdirectory(Transforms)
+add_subdirectory(OpenMP)
diff --git a/flang/include/flang/Optimizer/OpenMP/CMakeLists.txt b/flang/include/flang/Optimizer/OpenMP/CMakeLists.txt
new file mode 100644
index 0000000000000..d59573f0f7fd9
--- /dev/null
+++ b/flang/include/flang/Optimizer/OpenMP/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name FlangOpenMP)
+
+add_public_tablegen_target(FlangOpenMPPassesIncGen)
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
new file mode 100644
index 0000000000000..95a05b3005073
--- /dev/null
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -0,0 +1,30 @@
+//===- Passes.h - OpenMP pass entry points ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares OpenMP pass entry points.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_OPENMP_PASSES_H
+#define FORTRAN_OPTIMIZER_OPENMP_PASSES_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include <memory>
+
+namespace flangomp {
+#define GEN_PASS_DECL
+#define GEN_PASS_REGISTRATION
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+
+bool shouldUseWorkshareLowering(mlir::Operation *op);
+
+} // namespace flangomp
+
+#endif // FORTRAN_OPTIMIZER_OPENMP_PASSES_H
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
new file mode 100644
index 0000000000000..6f636ec1df616
--- /dev/null
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -0,0 +1,18 @@
+//===-- Passes.td - Flang OpenMP pass definition file ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_DIALECT_OPENMP_PASSES
+#define FORTRAN_DIALECT_OPENMP_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LowerWorkshare : Pass<"lower-workshare"> {
+  let summary = "Lower workshare construct";
+}
+
+#endif //FORTRAN_DIALECT_OPENMP_PASSES
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 7df5044949463..594369fc2ffe5 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -17,6 +17,7 @@
 #include "mlir/Transforms/Passes.h"
 #include "flang/Optimizer/CodeGen/CodeGen.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Support/CommandLine.h"
@@ -344,6 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
+  pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Frontend/CMakeLists.txt b/flang/lib/Frontend/CMakeLists.txt
index c20b9096aff49..ecdcc73d61ec1 100644
--- a/flang/lib/Frontend/CMakeLists.txt
+++ b/flang/lib/Frontend/CMakeLists.txt
@@ -38,6 +38,7 @@ add_flang_library(flangFrontend
   FIRTransforms
   HLFIRDialect
   HLFIRTransforms
+  FlangOpenMPTransforms
   MLIRTransforms
   MLIRBuiltinToLLVMIRTranslation
   MLIRLLVMToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CMakeLists.txt b/flang/lib/Optimizer/CMakeLists.txt
index 4a602162ed2b7..dd153ac33c0fb 100644
--- a/flang/lib/Optimizer/CMakeLists.txt
+++ b/flang/lib/Optimizer/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(HLFIR)
 add_subdirectory(Support)
 add_subdirectory(Transforms)
 add_subdirectory(Analysis)
+add_subdirectory(OpenMP)
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
new file mode 100644
index 0000000000000..74419327d76d0
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -0,0 +1,26 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+
+add_flang_library(FlangOpenMPTransforms
+  LowerWorkshare.cpp
+
+  DEPENDS
+  FIRDialect
+  FlangOpenMPPassesIncGen
+  ${dialect_libs}
+
+  LINK_LIBS
+  FIRAnalysis
+  FIRDialect
+  FIRBuilder
+  FIRDialectSupport
+  FIRSupport
+  FIRTransforms
+  HLFIRDialect
+  MLIRIR
+  ${dialect_libs}
+
+  LINK_COMPONENTS
+  AsmParser
+  AsmPrinter
+  Remarks
+)
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
new file mode 100644
index 0000000000000..40975552d1fe3
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -0,0 +1,259 @@
+//===- LowerWorkshare.cpp - Lower the OpenMP workshare construct -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Lower omp workshare construct.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+bool shouldUseWorkshareLowering(Operation *op) {
+  auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
+  if (!workshare)
+    return false;
+  return workshare->getParentOfType<omp::ParallelOp>();
+}
+} // namespace flangomp
+
+namespace {
+
+struct SingleRegion {
+  Block::iterator begin, end;
+};
+
+static bool isSupportedByFirAlloca(Type ty) {
+  return !isa<fir::ReferenceType>(ty);
+}
+
+static bool isSafeToParallelize(Operation *op) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface) {
+    return false;
+  }
+  interface.getEffects(effects);
+  if (effects.empty())
+    return true;
+
+  return false;
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.wsloop {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.wsloop {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values defined in omp.single
+/// regions which need to be accessed by all threads in the closest omp.parallel.
+///
+/// TODO: currently we need to be able to access the encompassing omp.parallel so
+/// that we can allocate temporaries accessible by all threads outside of it.
+/// In case we do not find it, we fall back to converting the omp.workshare to
+/// omp.single.
+/// To better handle this we should probably enable yielding values out of an
+/// omp.single, which will be supported by the omp runtime.
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  assert(wsOp.getRegion().getBlocks().size() == 1);
+
+  Location loc = wsOp->getLoc();
+
+  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
+  if (!parallelOp) {
+    wsOp.emitWarning("cannot handle workshare, converting to single");
+    Operation *terminator = wsOp.getRegion().front().getTerminator();
+    wsOp->getBlock()->getOperations().splice(
+        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
+    terminator->erase();
+    return;
+  }
+
+  OpBuilder allocBuilder(parallelOp);
+  OpBuilder rootBuilder(wsOp);
+  IRMapping rootMapping;
+
+  omp::SingleOp singleOp = nullptr;
+
+  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
+                              IRMapping singleMapping) {
+    if (auto reloaded = rootMapping.lookupOrNull(v))
+      return;
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type ty = v.getType();
+    Value alloc, reloaded;
+    if (isSupportedByFirAlloca(ty)) {
+      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+    } else {
+      auto one = allocBuilder.create<LLVM::ConstantOp>(
+          loc, allocBuilder.getI32Type(), 1);
+      alloc =
+          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+      Value toStore = singleBuilder
+                          .create<UnrealizedConversionCastOp>(
+                              loc, llvmPtrTy, singleMapping.lookup(v))
+                          .getResult(0);
+      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
+      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded =
+          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+              .getResult(0);
+    }
+    rootMapping.map(v, reloaded);
+  };
+
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+    IRMapping singleMapping = rootMapping;
+
+    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+      singleBuilder.clone(op, singleMapping);
+      if (isSafeToParallelize(&op)) {
+        rootBuilder.clone(op, rootMapping);
+      } else {
+        // Prepare reloaded values for results of operations that cannot be
+        // safely parallelized and which are used after the region `sr`
+        for (auto res : op.getResults()) {
+          for (auto &use : res.getUses()) {
+            Operation *user = use.getOwner();
+            while (user->getParentOp() != wsOp)
+              user = user->getParentOp();
+            if (!user->isBeforeInBlock(&*sr.end)) {
+              // We need to reload
+              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            }
+          }
+        }
+      }
+    }
+    singleBuilder.create<omp::TerminatorOp>(loc);
+  };
+
+  Block *wsBlock = &wsOp.getRegion().front();
+  assert(wsBlock->getTerminator()->getNumOperands() == 0);
+  Operation *terminator = wsBlock->getTerminator();
+
+  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+
+  auto it = wsBlock->begin();
+  auto getSingleRegion = [&]() {
+    if (&*it == terminator)
+      return false;
+    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+      regions.push_back(pop);
+      it++;
+      return true;
+    }
+    SingleRegion sr;
+    sr.begin = it;
+    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+      it++;
+    sr.end = it;
+    assert(sr.begin != sr.end);
+    regions.push_back(sr);
+    return true;
+  };
+  while (getSingleRegion())
+    ;
+
+  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
+    bool isLast = i + 1 == regions.size();
+    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
+      omp::SingleOperands singleOperands;
+      if (isLast)
+        singleOperands.nowait = rootBuilder.getUnitAttr();
+      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      OpBuilder singleBuilder(singleOp);
+      singleBuilder.createBlock(&singleOp.getRegion());
+      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
+    } else {
+      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
+      if (!isLast)
+        rootBuilder.create<omp::BarrierOp>(loc);
+    }
+  }
+
+  if (!wsOp.getNowait())
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  wsOp->erase();
+
+  return;
+}
+
+class LowerWorksharePass
+    : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
+public:
+  void runOnOperation() override {
+    SmallPtrSet<Operation *, 8> parents;
+    getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
+      Operation *isolatedParent =
+          wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+      parents.insert(isolatedParent);
+
+      lowerWorkshare(wsOp);
+    });
+
+    // Do folding
+    for (Operation *isolatedParent : parents) {
+      RewritePatternSet patterns(&getContext());
+      GreedyRewriteConfig config;
+      // prevent the pattern driver from merging blocks
+      config.enableRegionSimplification =
+          mlir::GreedySimplifyRegionLevel::Disabled;
+      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
+                                              std::move(patterns), config))) {
+        emitError(isolatedParent->getLoc(), "error in lower workshare\n");
+        signalPassFailure();
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/test/HLFIR/bufferize-workshare.fir b/flang/test/HLFIR/bufferize-workshare.fir
new file mode 100644
index 0000000000000..86a2f031478dd
--- /dev/null
+++ b/flang/test/HLFIR/bufferize-workshare.fir
@@ -0,0 +1,58 @@
+// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @simple(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           omp.parallel {
+// CHECK:             omp.workshare {
+// CHECK:               %[[VAL_1:.*]] = arith.constant 42 : index
+// CHECK:               %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK:               %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_7:.*]] = arith.constant true
+// CHECK:               %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:               omp.wsloop {
+// CHECK:                 omp.loop_nest (%[[VAL_9:.*]]) : index = (%[[VAL_8]]) to (%[[VAL_1]]) inclusive step (%[[VAL_8]]) {
+// CHECK:                   %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
+// CHECK:                   %[[VAL_12:.*]] = arith.subi %[[VAL_11]], %[[VAL_2]] : i32
+// CHECK:                   %[[VAL_13:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   hlfir.assign %[[VAL_12]] to %[[VAL_13]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_14:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_7]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_6]]#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               hlfir.assign %[[VAL_6]]#0 to %[[VAL_4]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_6]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @simple(%arg: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+    omp.workshare {
+      %c42 = arith.constant 42 : index
+      %c1_i32 = arith.constant 1 : i32
+      %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+      %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+      %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+      ^bb0(%i: index):
+        %ref = hlfir.designate %array#0 (%i) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        %val = fir.load %ref : !fir.ref<i32>
+        %sub = arith.subi %val, %c1_i32 : i32
+        hlfir.yield_element %sub : i32
+      }
+      hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+      hlfir.destroy %elemental : !hlfir.expr<42xi32>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
new file mode 100644
index 0000000000000..a8d36443f08bd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -0,0 +1,81 @@
+// RUN: fir-opt --lower-workshare %s | FileCheck %s
+
+module {
+// CHECK-LABEL:   func.func @simple(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           omp.parallel {
+// CHECK:             omp.single {
+// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
+// CHECK:               llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
+// CHECK:               %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
+// CHECK:             %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
+// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                 omp.yield
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.barrier
+// CHECK:             omp.single nowait {
+// CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.barrier
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+  func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
+    omp.parallel {
+      omp.workshare {
+        %c42 = arith.constant 42 : index
+        %c1_i32 = arith.constant 1 : i32
+        %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+        %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+        %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+        %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+        %true = arith.constant true
+        %c1 = arith.constant 1 : index
+        omp.wsloop {
+          omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+            %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+            %8 = fir.load %7 : !fir.ref<i32>
+            %9 = arith.subi %8, %c1_i32 : i32
+            %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+            hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+            omp.yield
+          }
+          omp.terminator
+        }
+        %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+        fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    return
+  }
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare5.mlir b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
new file mode 100644
index 0000000000000..177f8aa8f86c7
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
@@ -0,0 +1,42 @@
+// XFAIL: *
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// TODO: we can lower these, but we have no guarantee that the parent of
+// omp.workshare supports multi-block regions, so we fail for now.
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+    ^bb1:
+      %c1 = arith.constant 1 : i32
+      cf.br ^bb3(%c1: i32)
+    ^bb3(%arg1: i32):
+      "test.test2"(%arg1) : (i32) -> ()
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+    ^bb1:
+      %c1 = arith.constant 1 : i32
+      cf.br ^bb3(%c1: i32)
+    ^bb2:
+      "test.test2"(%r) : (i32) -> ()
+      omp.terminator
+    ^bb3(%arg1: i32):
+      %r = "test.test2"(%arg1) : (i32) -> i32
+      cf.br ^bb2
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/tools/bbc/CMakeLists.txt b/flang/tools/bbc/CMakeLists.txt
index 9410fd0056600..69316d4dc61de 100644
--- a/flang/tools/bbc/CMakeLists.txt
+++ b/flang/tools/bbc/CMakeLists.txt
@@ -25,6 +25,7 @@ FIRTransforms
 FIRBuilder
 HLFIRDialect
 HLFIRTransforms
+FlangOpenMPTransforms
 ${dialect_libs}
 ${extension_libs}
 MLIRAffineToStandard
diff --git a/flang/tools/fir-opt/CMakeLists.txt b/flang/tools/fir-opt/CMakeLists.txt
index 43679a9d53578..4c6dbf7d9c8c3 100644
--- a/flang/tools/fir-opt/CMakeLists.txt
+++ b/flang/tools/fir-opt/CMakeLists.txt
@@ -19,6 +19,7 @@ target_link_libraries(fir-opt PRIVATE
   FIRCodeGen
   HLFIRDialect
   HLFIRTransforms
+  FlangOpenMPTransforms
   FIRAnalysis
   ${test_libs}
   ${dialect_libs}
diff --git a/flang/tools/fir-opt/fir-opt.cpp b/flang/tools/fir-opt/fir-opt.cpp
index 1846c1b317848..f75fba27c68f0 100644
--- a/flang/tools/fir-opt/fir-opt.cpp
+++ b/flang/tools/fir-opt/fir-opt.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "flang/Optimizer/CodeGen/CodeGen.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
 #include "flang/Optimizer/Support/InitFIR.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 
@@ -34,6 +35,7 @@ int main(int argc, char **argv) {
   fir::registerOptCodeGenPasses();
   fir::registerOptTransformPasses();
   hlfir::registerHLFIRPasses();
+  flangomp::registerFlangOpenMPPasses();
 #ifdef FLANG_INCLUDE_TESTS
   fir::test::registerTestFIRAliasAnalysisPass();
   mlir::registerSideEffectTestPasses();
diff --git a/flang/tools/tco/CMakeLists.txt b/flang/tools/tco/CMakeLists.txt
index 808219ac361f2..698a398547c77 100644
--- a/flang/tools/tco/CMakeLists.txt
+++ b/flang/tools/tco/CMakeLists.txt
@@ -17,6 +17,7 @@ target_link_libraries(tco PRIVATE
   FIRBuilder
   HLFIRDialect
   HLFIRTransforms
+  FlangOpenMPTransforms
   ${dialect_libs}
   ${extension_libs}
   MLIRIR

>From e33a4bfc0505eff8ce24066128f9069fd67827b8 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:41:09 +0900
Subject: [PATCH 08/23] Change to workshare loop wrapper op
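
The lowering now looks for omp.workshare_loop_wrapper ops instead of omp.wsloop
ops inside omp.workshare. A minimal sketch of the intended input and output
(SSA names %lb, %ub, %step and %i are illustrative placeholders):

  omp.workshare {
    omp.workshare_loop_wrapper {
      omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%step) {
        // loop body
        omp.yield
      }
      omp.terminator
    }
    omp.terminator
  }

Each wrapper is rewritten to a worksharing loop, roughly:

  omp.wsloop {
    omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%step) {
      // loop body
      omp.yield
    }
    omp.terminator
  }

with nowait added on the wsloop when it is the last unit of work in the
workshare block.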

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 24 ++++++++++++-------
 flang/test/HLFIR/bufferize-workshare.fir      |  4 ++--
 .../Transforms/OpenMP/lower-workshare.mlir    |  5 ++--
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 40975552d1fe3..cb342b60de4e8 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/iterator_range.h"
 
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
 #include <variant>
 
 namespace flangomp {
@@ -73,7 +74,7 @@ static bool isSafeToParallelize(Operation *op) {
 ///
 /// omp.workshare {
 ///   %a = fir.allocmem
-///   omp.wsloop {}
+///   omp.workshare_loop_wrapper {}
 ///   fir.call Assign %b %a
 ///   fir.freemem %a
 /// }
@@ -85,7 +86,7 @@ static bool isSafeToParallelize(Operation *op) {
 ///   fir.store %a %tmp
 /// }
 /// %a_reloaded = fir.load %tmp
-/// omp.wsloop {}
+/// omp.workshare_loop_wrapper {}
 /// omp.single {
 ///   fir.call Assign %b %a_reloaded
 ///   fir.freemem %a_reloaded
@@ -180,20 +181,20 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
   assert(wsBlock->getTerminator()->getNumOperands() == 0);
   Operation *terminator = wsBlock->getTerminator();
 
-  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+  SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
 
   auto it = wsBlock->begin();
   auto getSingleRegion = [&]() {
     if (&*it == terminator)
       return false;
-    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+    if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
       regions.push_back(pop);
       it++;
       return true;
     }
     SingleRegion sr;
     sr.begin = it;
-    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+    while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
       it++;
     sr.end = it;
     assert(sr.begin != sr.end);
@@ -214,9 +215,16 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
       singleBuilder.createBlock(&singleOp.getRegion());
       moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
     } else {
-      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
-      if (!isLast)
-        rootBuilder.create<omp::BarrierOp>(loc);
+      omp::WsloopOperands wsloopOperands;
+      if (isLast)
+        wsloopOperands.nowait = rootBuilder.getUnitAttr();
+      auto wsloop =
+          rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+      auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
+      auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+          rootBuilder.clone(*wslw, rootMapping));
+      wsloop.getRegion().takeBody(clonedWslw.getRegion());
+      clonedWslw->erase();
     }
   }
 
diff --git a/flang/test/HLFIR/bufferize-workshare.fir b/flang/test/HLFIR/bufferize-workshare.fir
index 86a2f031478dd..33b368a62eaab 100644
--- a/flang/test/HLFIR/bufferize-workshare.fir
+++ b/flang/test/HLFIR/bufferize-workshare.fir
@@ -12,7 +12,7 @@
 // CHECK:               %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               %[[VAL_7:.*]] = arith.constant true
 // CHECK:               %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK:               omp.wsloop {
+// CHECK:               "omp.workshare_loop_wrapper"() ({
 // CHECK:                 omp.loop_nest (%[[VAL_9:.*]]) : index = (%[[VAL_8]]) to (%[[VAL_1]]) inclusive step (%[[VAL_8]]) {
 // CHECK:                   %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
 // CHECK:                   %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
@@ -22,7 +22,7 @@
 // CHECK:                   omp.yield
 // CHECK:                 }
 // CHECK:                 omp.terminator
-// CHECK:               }
+// CHECK:               }) : () -> ()
 // CHECK:               %[[VAL_14:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
 // CHECK:               %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_7]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
 // CHECK:               %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_6]]#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index a8d36443f08bd..cb5791d35916a 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -34,7 +34,6 @@ module {
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             omp.barrier
 // CHECK:             omp.single nowait {
 // CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
 // CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
@@ -56,7 +55,7 @@ module {
         %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
         %true = arith.constant true
         %c1 = arith.constant 1 : index
-        omp.wsloop {
+        "omp.workshare_loop_wrapper"() ({
           omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
             %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
             %8 = fir.load %7 : !fir.ref<i32>
@@ -66,7 +65,7 @@ module {
             omp.yield
           }
           omp.terminator
-        }
+        }) : () -> ()
         %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
         %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
         %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>

>From 08ad589723f61af9108ee5c91c184733f7294dac Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:47:27 +0900
Subject: [PATCH 09/23] Move single op declaration

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index cb342b60de4e8..2322d2acbc013 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -120,8 +120,6 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
   OpBuilder rootBuilder(wsOp);
   IRMapping rootMapping;
 
-  omp::SingleOp singleOp = nullptr;
-
   auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
                               IRMapping singleMapping) {
     if (auto reloaded = rootMapping.lookupOrNull(v))
@@ -210,7 +208,8 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
       omp::SingleOperands singleOperands;
       if (isLast)
         singleOperands.nowait = rootBuilder.getUnitAttr();
-      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      omp::SingleOp singleOp =
+          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
       OpBuilder singleBuilder(singleOp);
       singleBuilder.createBlock(&singleOp.getRegion());
       moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);

>From 6a411b99fde5ca5ce13bfba38364491adc75a8a3 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 17:13:58 +0900
Subject: [PATCH 10/23] Schedule pass properly

---
 flang/include/flang/Tools/CLOptions.inc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 594369fc2ffe5..2edaf0c6a0ae8 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  pm.addPass(flangomp::createLowerWorkshare());
+  addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed

>From 0429ba8e1bd89f7b564ae4feb7401b91859952d5 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 00:30:40 +0900
Subject: [PATCH 11/23] Correctly handle nested loop nests to be
 parallelized by workshare
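
This makes mustParallelizeOp walk into nested regions and makes
parallelizeRegion recurse into region-carrying ops, so a loop wrapper that is
not an immediate child of the omp.workshare block is still lowered. A sketch of
the kind of nesting that is now handled (%cond and the loop bounds are
illustrative placeholders):

  omp.workshare {
    fir.if %cond {
      omp.workshare_loop_wrapper {
        omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%step) {
          omp.yield
        }
        omp.terminator
      }
    }
    omp.terminator
  }

The fir.if is cloned without its region and parallelizeRegion is called
recursively on its body, where the wrapper is then converted to omp.wsloop as
usual.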

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 256 ++++++++++--------
 1 file changed, 138 insertions(+), 118 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 2322d2acbc013..8e79d1401c01c 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -19,9 +19,14 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/iterator_range.h"
 
+#include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <variant>
 
 namespace flangomp {
@@ -52,90 +57,40 @@ static bool isSupportedByFirAlloca(Type ty) {
   return !isa<fir::ReferenceType>(ty);
 }
 
-static bool isSafeToParallelize(Operation *op) {
-  if (isa<fir::DeclareOp>(op))
-    return true;
-
-  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
-  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
-  if (!interface) {
-    return false;
-  }
-  interface.getEffects(effects);
-  if (effects.empty())
-    return true;
-
-  return false;
+static bool mustParallelizeOp(Operation *op) {
+  return op
+      ->walk(
+          [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
+      .wasInterrupted();
 }
 
-/// Lowers workshare to a sequence of single-thread regions and parallel loops
-///
-/// For example:
-///
-/// omp.workshare {
-///   %a = fir.allocmem
-///   omp.workshare_loop_wrapper {}
-///   fir.call Assign %b %a
-///   fir.freemem %a
-/// }
-///
-/// becomes
-///
-/// omp.single {
-///   %a = fir.allocmem
-///   fir.store %a %tmp
-/// }
-/// %a_reloaded = fir.load %tmp
-/// omp.workshare_loop_wrapper {}
-/// omp.single {
-///   fir.call Assign %b %a_reloaded
-///   fir.freemem %a_reloaded
-/// }
-///
-/// Note that we allocate temporary memory for values in omp.single's which need
-/// to be accessed in all threads in the closest omp.parallel
-///
-/// TODO currently we need to be able to access the encompassing omp.parallel so
-/// that we can allocate temporaries accessible by all threads outside of it.
-/// In case we do not find it, we fall back to converting the omp.workshare to
-/// omp.single.
-/// To better handle this we should probably enable yielding values out of an
-/// omp.single which will be supported by the omp runtime.
-void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
-  assert(wsOp.getRegion().getBlocks().size() == 1);
-
-  Location loc = wsOp->getLoc();
+static bool isSafeToParallelize(Operation *op) {
+  return isa<fir::DeclareOp>(op) || isPure(op);
+}
 
-  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
-  if (!parallelOp) {
-    wsOp.emitWarning("cannot handle workshare, converting to single");
-    Operation *terminator = wsOp.getRegion().front().getTerminator();
-    wsOp->getBlock()->getOperations().splice(
-        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
-    terminator->erase();
-    return;
-  }
-
-  OpBuilder allocBuilder(parallelOp);
-  OpBuilder rootBuilder(wsOp);
-  IRMapping rootMapping;
+static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
+                              IRMapping &rootMapping, Location loc) {
+  Operation *parentOp = sourceRegion.getParentOp();
+  OpBuilder rootBuilder(sourceRegion.getContext());
 
+  // TODO need to copyprivate the allocas
   auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
                               IRMapping singleMapping) {
+    OpBuilder allocaBuilder(&targetRegion.front().front());
     if (auto reloaded = rootMapping.lookupOrNull(v))
       return;
-    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
     Value alloc, reloaded;
     if (isSupportedByFirAlloca(ty)) {
-      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
       singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
       reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
     } else {
-      auto one = allocBuilder.create<LLVM::ConstantOp>(
-          loc, allocBuilder.getI32Type(), 1);
+      auto one = allocaBuilder.create<LLVM::ConstantOp>(
+          loc, allocaBuilder.getI32Type(), 1);
       alloc =
-          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+          allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
       Value toStore = singleBuilder
                           .create<UnrealizedConversionCastOp>(
                               loc, llvmPtrTy, singleMapping.lookup(v))
@@ -162,9 +117,10 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
         for (auto res : op.getResults()) {
           for (auto &use : res.getUses()) {
             Operation *user = use.getOwner();
-            while (user->getParentOp() != wsOp)
+            while (user->getParentOp() != parentOp)
               user = user->getParentOp();
-            if (!user->isBeforeInBlock(&*sr.end)) {
+            if (!(user->isBeforeInBlock(&*sr.end) &&
+                  sr.begin->isBeforeInBlock(user))) {
               // We need to reload
               mapReloadedValue(use.get(), singleBuilder, singleMapping);
             }
@@ -175,61 +131,125 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
     singleBuilder.create<omp::TerminatorOp>(loc);
   };
 
-  Block *wsBlock = &wsOp.getRegion().front();
-  assert(wsBlock->getTerminator()->getNumOperands() == 0);
-  Operation *terminator = wsBlock->getTerminator();
+  // TODO Need to handle these (clone them) in dominator tree order
+  for (Block &block : sourceRegion) {
+    rootBuilder.createBlock(
+        &targetRegion, {}, block.getArgumentTypes(),
+        llvm::map_to_vector(block.getArguments(),
+                            [](BlockArgument arg) { return arg.getLoc(); }));
+    Operation *terminator = block.getTerminator();
 
-  SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
+    SmallVector<std::variant<SingleRegion, Operation *>> regions;
 
-  auto it = wsBlock->begin();
-  auto getSingleRegion = [&]() {
-    if (&*it == terminator)
-      return false;
-    if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
-      regions.push_back(pop);
-      it++;
+    auto it = block.begin();
+    auto getOneRegion = [&]() {
+      if (&*it == terminator)
+        return false;
+      if (mustParallelizeOp(&*it)) {
+        regions.push_back(&*it);
+        it++;
+        return true;
+      }
+      SingleRegion sr;
+      sr.begin = it;
+      while (&*it != terminator && !mustParallelizeOp(&*it))
+        it++;
+      sr.end = it;
+      assert(sr.begin != sr.end);
+      regions.push_back(sr);
       return true;
+    };
+    while (getOneRegion())
+      ;
+
+    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
+      bool isLast = i + 1 == regions.size();
+      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        omp::SingleOperands singleOperands;
+        if (isLast)
+          singleOperands.nowait = rootBuilder.getUnitAttr();
+        omp::SingleOp singleOp =
+            rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+        OpBuilder singleBuilder(singleOp);
+        singleBuilder.createBlock(&singleOp.getRegion());
+        moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
+      } else {
+        auto op = std::get<Operation *>(opOrSingle);
+        if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
+          omp::WsloopOperands wsloopOperands;
+          if (isLast)
+            wsloopOperands.nowait = rootBuilder.getUnitAttr();
+          auto wsloop =
+              rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+          auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+              rootBuilder.clone(*wslw, rootMapping));
+          wsloop.getRegion().takeBody(clonedWslw.getRegion());
+          clonedWslw->erase();
+        } else {
+          assert(mustParallelizeOp(op));
+          Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
+          for (auto [region, clonedRegion] :
+               llvm::zip(op->getRegions(), cloned->getRegions()))
+            parallelizeRegion(region, clonedRegion, rootMapping, loc);
+        }
+      }
     }
-    SingleRegion sr;
-    sr.begin = it;
-    while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
-      it++;
-    sr.end = it;
-    assert(sr.begin != sr.end);
-    regions.push_back(sr);
-    return true;
-  };
-  while (getSingleRegion())
-    ;
-
-  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
-    bool isLast = i + 1 == regions.size();
-    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
-      omp::SingleOperands singleOperands;
-      if (isLast)
-        singleOperands.nowait = rootBuilder.getUnitAttr();
-      omp::SingleOp singleOp =
-          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
-      OpBuilder singleBuilder(singleOp);
-      singleBuilder.createBlock(&singleOp.getRegion());
-      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
-    } else {
-      omp::WsloopOperands wsloopOperands;
-      if (isLast)
-        wsloopOperands.nowait = rootBuilder.getUnitAttr();
-      auto wsloop =
-          rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
-      auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
-      auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
-          rootBuilder.clone(*wslw, rootMapping));
-      wsloop.getRegion().takeBody(clonedWslw.getRegion());
-      clonedWslw->erase();
-    }
+
+    rootBuilder.clone(*block.getTerminator(), rootMapping);
   }
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.workshare_loop_wrapper {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.workshare_loop_wrapper {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which need
+/// to be accessed in all threads in the closest omp.parallel
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  Location loc = wsOp->getLoc();
+  IRMapping rootMapping;
+
+  OpBuilder rootBuilder(wsOp);
+
+  // TODO We need something like an scf.execute_region here, but that is not
+  // registered, so we use fir.if for now; however, fir.if does not seem to
+  // support multiple blocks, so this does not work for the multi-block case...
+  auto ifOp = rootBuilder.create<fir::IfOp>(
+      loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
+  ifOp.getThenRegion().front().erase();
+
+  parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
+
+  Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
+  assert(isa<omp::TerminatorOp>(terminatorOp));
+  OpBuilder termBuilder(terminatorOp);
 
   if (!wsOp.getNowait())
-    rootBuilder.create<omp::BarrierOp>(loc);
+    termBuilder.create<omp::BarrierOp>(loc);
+
+  termBuilder.create<fir::ResultOp>(loc, ValueRange());
 
+  terminatorOp->erase();
   wsOp->erase();
 
   return;

>From 73300b669a449734e276e93017248a163abacffe Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 00:33:57 +0900
Subject: [PATCH 12/23] Leave comments for shouldUseWorkshareLowering

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 8e79d1401c01c..40dae0fd848ef 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -40,10 +40,23 @@ using namespace mlir;
 
 namespace flangomp {
 bool shouldUseWorkshareLowering(Operation *op) {
-  auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
-  if (!workshare)
-    return false;
-  return workshare->getParentOfType<omp::ParallelOp>();
+  // TODO this is insufficient, as we could have
+  // omp.parallel {
+  //   omp.workshare {
+  //     omp.parallel {
+  //       hlfir.elemental {}
+  //
+  // Then this hlfir.elemental shall _not_ use the lowering for workshare
+  //
+  // Standard says:
+  //   For a parallel construct, the construct is a unit of work with respect to
+  //   the workshare construct. The statements contained in the parallel
+  //   construct are executed by a new thread team.
+  //
+  // TODO similarly for single, critical, etc. Need to think through the
+  // patterns and implement this function.
+  //
+  return op->getParentOfType<omp::WorkshareOp>();
 }
 } // namespace flangomp
 

>From 11c734d54ea7dc9deabdf31735a517c0c8362717 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 13:14:38 +0900
Subject: [PATCH 13/23] Use copyprivate to scatter val from omp.single

TODO still need to implement copy function
TODO transitive check for usage outside of omp.single not implemented yet
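
Values produced in a unit of work that is emitted as omp.single and needed by
the rest of the threads are now staged through allocas listed in the single's
copyprivate clause, each with a (currently empty) per-type copy function. A
sketch of the resulting IR, mirroring the lower-workshare.mlir test updated
later in this series (%tmp and %reloaded are illustrative names):

  %tmp = fir.alloca !fir.heap<!fir.array<42xi32>>
  omp.single copyprivate(%tmp -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
    // non-parallelizable ops; the produced heap value is stored to %tmp
    omp.terminator
  }
  %reloaded = fir.load %tmp : !fir.ref<!fir.heap<!fir.array<42xi32>>>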
---
 flang/include/flang/Optimizer/OpenMP/Passes.h |   2 +
 .../include/flang/Optimizer/OpenMP/Passes.td  |   3 +-
 flang/include/flang/Tools/CLOptions.inc       |   2 +-
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 138 ++++++++++++++----
 4 files changed, 111 insertions(+), 34 deletions(-)

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
index 95a05b3005073..af6fc6283521d 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.h
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -14,8 +14,10 @@
 #define FORTRAN_OPTIMIZER_OPENMP_PASSES_H
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
+
 #include <memory>
 
 namespace flangomp {
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 6f636ec1df616..b22d22859a929 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -11,7 +11,8 @@
 
 include "mlir/Pass/PassBase.td"
 
-def LowerWorkshare : Pass<"lower-workshare"> {
+// Needs to be scheduled on Module as we create functions in it
+def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 2edaf0c6a0ae8..594369fc2ffe5 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
+  pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 40dae0fd848ef..950737fccada7 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -8,25 +8,27 @@
 // Lower omp workshare construct.
 //===----------------------------------------------------------------------===//
 
-#include "flang/Optimizer/Dialect/FIROps.h"
-#include "flang/Optimizer/Dialect/FIRType.h"
-#include "flang/Optimizer/OpenMP/Passes.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVectorExtras.h"
-#include "llvm/ADT/iterator_range.h"
-
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
-#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
 #include <mlir/IR/Visitors.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
+
 #include <variant>
 
 namespace flangomp {
@@ -71,6 +73,8 @@ static bool isSupportedByFirAlloca(Type ty) {
 }
 
 static bool mustParallelizeOp(Operation *op) {
+  // TODO as in shouldUseWorkshareLowering we should be careful not to pick up
+  // workshare_loop_wrapper in nested omp.parallel ops
   return op
       ->walk(
           [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
@@ -78,7 +82,33 @@ static bool mustParallelizeOp(Operation *op) {
 }
 
 static bool isSafeToParallelize(Operation *op) {
-  return isa<fir::DeclareOp>(op) || isPure(op);
+  return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
+         isMemoryEffectFree(op);
+}
+
+static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
+                                         fir::FirOpBuilder builder) {
+  mlir::ModuleOp module = builder.getModule();
+  mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
+
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+
+  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
+    return decl;
+  // create function
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
+  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
+  mlir::func::FuncOp funcOp =
+      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+                      {loc, loc});
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+  builder.create<mlir::func::ReturnOp>(loc);
+  return funcOp;
 }
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
@@ -86,19 +116,23 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
   Operation *parentOp = sourceRegion.getParentOp();
   OpBuilder rootBuilder(sourceRegion.getContext());
 
+  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
+  OpBuilder copyFuncBuilder(m.getBodyRegion());
+  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
+
   // TODO need to copyprivate the allocas
-  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
-                              IRMapping singleMapping) {
-    OpBuilder allocaBuilder(&targetRegion.front().front());
+  auto mapReloadedValue =
+      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
+          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
     if (auto reloaded = rootMapping.lookupOrNull(v))
-      return;
+      return nullptr;
     Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
     Value alloc, reloaded;
     if (isSupportedByFirAlloca(ty)) {
       alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
       singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
-      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+      reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
     } else {
       auto one = allocaBuilder.create<LLVM::ConstantOp>(
           loc, allocaBuilder.getI32Type(), 1);
@@ -109,21 +143,25 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               loc, llvmPtrTy, singleMapping.lookup(v))
                           .getResult(0);
       singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
-      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
       reloaded =
-          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+          parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
               .getResult(0);
     }
     rootMapping.map(v, reloaded);
+    return alloc;
   };
 
-  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
+                          OpBuilder singleBuilder,
+                          OpBuilder parallelBuilder) -> SmallVector<Value> {
     IRMapping singleMapping = rootMapping;
+    SmallVector<Value> copyPrivate;
 
     for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
       singleBuilder.clone(op, singleMapping);
       if (isSafeToParallelize(&op)) {
-        rootBuilder.clone(op, rootMapping);
+        parallelBuilder.clone(op, rootMapping);
       } else {
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
@@ -132,16 +170,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
             Operation *user = use.getOwner();
             while (user->getParentOp() != parentOp)
               user = user->getParentOp();
-            if (!(user->isBeforeInBlock(&*sr.end) &&
-                  sr.begin->isBeforeInBlock(user))) {
-              // We need to reload
-              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            // TODO we need to look at transitively used vals
+            if (true || !(user->isBeforeInBlock(&*sr.end) &&
+                          sr.begin->isBeforeInBlock(user))) {
+              auto alloc =
+                  mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
+                                   parallelBuilder, singleMapping);
+              if (alloc)
+                copyPrivate.push_back(alloc);
             }
           }
         }
       }
     }
     singleBuilder.create<omp::TerminatorOp>(loc);
+    return copyPrivate;
   };
 
   // TODO Need to handle these (clone them) in dominator tree order
@@ -178,14 +221,45 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
       bool isLast = i + 1 == regions.size();
       if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        OpBuilder singleBuilder(sourceRegion.getContext());
+        Block *singleBlock = new Block();
+        singleBuilder.setInsertionPointToStart(singleBlock);
+
+        OpBuilder allocaBuilder(sourceRegion.getContext());
+        Block *allocaBlock = new Block();
+        allocaBuilder.setInsertionPointToStart(allocaBlock);
+
+        OpBuilder parallelBuilder(sourceRegion.getContext());
+        Block *parallelBlock = new Block();
+        parallelBuilder.setInsertionPointToStart(parallelBlock);
+
         omp::SingleOperands singleOperands;
         if (isLast)
           singleOperands.nowait = rootBuilder.getUnitAttr();
+        auto insPtAtSingle = rootBuilder.saveInsertionPoint();
+        singleOperands.copyprivateVars =
+            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
+                         singleBuilder, parallelBuilder);
+        for (auto var : singleOperands.copyprivateVars) {
+          Type ty;
+          if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
+            ty = firAlloca.getAllocatedType();
+          } else {
+            llvm_unreachable("unexpected");
+          }
+          mlir::func::FuncOp funcOp =
+              createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
+          singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
+        }
         omp::SingleOp singleOp =
             rootBuilder.create<omp::SingleOp>(loc, singleOperands);
-        OpBuilder singleBuilder(singleOp);
-        singleBuilder.createBlock(&singleOp.getRegion());
-        moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
+        singleOp.getRegion().push_back(singleBlock);
+        rootBuilder.getInsertionBlock()->getOperations().splice(
+            rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
+        targetRegion.front().getOperations().splice(
+            singleOp->getIterator(), allocaBlock->getOperations());
+        delete allocaBlock;
+        delete parallelBlock;
       } else {
         auto op = std::get<Operation *>(opOrSingle);
         if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {

>From e3bce2f99b63d1a3f67e9e6f262a37dd1261b303 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 13:39:49 +0900
Subject: [PATCH 14/23] Transitively check for users outside of single op

TODO need to implement copy func
TODO need to hoist allocas outside of single regions
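
The check is transitive because a value produced by a non-parallelizable op may
escape the single region only through pure ops that every thread re-executes.
A sketch (the test.* op is an illustrative unregistered op, in the style of the
tests):

  omp.workshare {
    %a = "test.sideeffect"() : () -> index   // not safe to parallelize
    %b = arith.addi %a, %a : index           // pure, re-executed by all threads
    omp.workshare_loop_wrapper {
      omp.loop_nest (%i) : index = (%b) to (%b) inclusive step (%b) {
        omp.yield
      }
      omp.terminator
    }
    omp.terminator
  }

No use of %a appears after the single-thread region, but %b does, so %a still
has to be stored to a temporary inside the omp.single and reloaded before %b is
recomputed outside of it.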
---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 51 ++++++++++++++-----
 1 file changed, 37 insertions(+), 14 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 950737fccada7..2e88d852ff2cb 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -111,6 +111,38 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
   return funcOp;
 }
 
+static bool isUserOutsideSR(Operation *user, Operation *parentOp,
+                            SingleRegion sr) {
+  while (user->getParentOp() != parentOp)
+    user = user->getParentOp();
+  return sr.begin->getBlock() != user->getBlock() ||
+         !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
+}
+
+static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
+  Block *srBlock = sr.begin->getBlock();
+  Operation *parentOp = srBlock->getParentOp();
+
+  for (auto &use : v.getUses()) {
+    Operation *user = use.getOwner();
+    if (isUserOutsideSR(user, parentOp, sr))
+      return true;
+
+    // Results of nested users cannot be used outside of the SR
+    if (user->getBlock() != srBlock)
+      continue;
+
+    // A non-safe to parallelize operation will be handled separately
+    if (!isSafeToParallelize(user))
+      continue;
+
+    for (auto res : user->getResults())
+      if (isTransitivelyUsedOutside(res, sr))
+        return true;
+  }
+  return false;
+}
+
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
   Operation *parentOp = sourceRegion.getParentOp();
@@ -166,19 +198,11 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
         for (auto res : op.getResults()) {
-          for (auto &use : res.getUses()) {
-            Operation *user = use.getOwner();
-            while (user->getParentOp() != parentOp)
-              user = user->getParentOp();
-            // TODO we need to look at transitively used vals
-            if (true || !(user->isBeforeInBlock(&*sr.end) &&
-                          sr.begin->isBeforeInBlock(user))) {
-              auto alloc =
-                  mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
-                                   parallelBuilder, singleMapping);
-              if (alloc)
-                copyPrivate.push_back(alloc);
-            }
+          if (isTransitivelyUsedOutside(res, sr)) {
+            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
+                                          parallelBuilder, singleMapping);
+            if (alloc)
+              copyPrivate.push_back(alloc);
           }
         }
       }
@@ -236,7 +260,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         omp::SingleOperands singleOperands;
         if (isLast)
           singleOperands.nowait = rootBuilder.getUnitAttr();
-        auto insPtAtSingle = rootBuilder.saveInsertionPoint();
         singleOperands.copyprivateVars =
             moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                          singleBuilder, parallelBuilder);

>From e5a577b3dee837b72ee29e42ee1c13142442d5ae Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 14:30:43 +0900
Subject: [PATCH 15/23] Add tests

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  25 +-
 .../Transforms/OpenMP/lower-workshare.mlir    | 230 +++++++++++++-----
 2 files changed, 188 insertions(+), 67 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 2e88d852ff2cb..30af2556cf4ca 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -18,6 +18,7 @@
 #include <llvm/ADT/iterator_range.h>
 #include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
 #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
 #include <mlir/IR/BuiltinOps.h>
@@ -75,6 +76,14 @@ static bool isSupportedByFirAlloca(Type ty) {
 static bool mustParallelizeOp(Operation *op) {
   // TODO as in shouldUseWorkshareLowering we be careful not to pick up
   // workshare_loop_wrapper in nested omp.parallel ops
+  //
+  // e.g.
+  //
+  // omp.parallel {
+  //   omp.workshare {
+  //     omp.parallel {
+  //       omp.workshare {
+  //         omp.workshare_loop_wrapper {}
   return op
       ->walk(
           [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
@@ -89,10 +98,14 @@ static bool isSafeToParallelize(Operation *op) {
 static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                          fir::FirOpBuilder builder) {
   mlir::ModuleOp module = builder.getModule();
-  mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
-
-  std::string copyFuncName =
-      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  std::string copyFuncName;
+  if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
+    mlir::Type eleTy = rt.getEleTy();
+    copyFuncName =
+        fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  } else {
+    copyFuncName = "_workshare_copy_llvm_ptr";
+  }
 
   if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
     return decl;
@@ -145,9 +158,7 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
-  Operation *parentOp = sourceRegion.getParentOp();
   OpBuilder rootBuilder(sourceRegion.getContext());
-
   ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
   OpBuilder copyFuncBuilder(m.getBodyRegion());
   fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
@@ -268,7 +279,7 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
             ty = firAlloca.getAllocatedType();
           } else {
-            llvm_unreachable("unexpected");
+            ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
           }
           mlir::func::FuncOp funcOp =
               createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index cb5791d35916a..19123e71cacf6 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -1,80 +1,190 @@
-// RUN: fir-opt --lower-workshare %s | FileCheck %s
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
-module {
-// CHECK-LABEL:   func.func @simple(
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+    omp.workshare {
+      %c42 = arith.constant 42 : index
+      %c1_i32 = arith.constant 1 : i32
+      %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+      %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+      %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+      %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+      %true = arith.constant true
+      %c1 = arith.constant 1 : index
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+          %8 = fir.load %7 : !fir.ref<i32>
+          %9 = arith.subi %8, %c1_i32 : i32
+          %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+          hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+      %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+      fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+
+// -----
+
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.workshare {
+    %c1_i32 = arith.constant 1 : i32
+    %alloc = fir.alloca i32
+    fir.store %c1_i32 to %alloc : !fir.ref<i32>
+    %c42 = arith.constant 42 : index
+    %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+    %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+    %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+    %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+    %true = arith.constant true
+    %c1 = arith.constant 1 : index
+    "omp.workshare_loop_wrapper"() ({
+      omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+        %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        %8 = fir.load %7 : !fir.ref<i32>
+        %ld = fir.load %alloc : !fir.ref<i32>
+        %n8 = arith.subi %8, %ld : i32
+        %9 = arith.subi %n8, %c1_i32 : i32
+        %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
+    }) : () -> ()
+    %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    "test.test1"(%alloc) : (!fir.ref<i32>) -> ()
+    hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+    fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+    omp.terminator
+  }
+  return
+}
+
+
+// CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK:           %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
-// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           %[[VAL_4:.*]] = arith.constant true
 // CHECK:           omp.parallel {
-// CHECK:             omp.single {
-// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
-// CHECK:               llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
-// CHECK:               %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             fir.if %[[VAL_4]] {
+// CHECK:               %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:                 %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:                 %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:                 %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:                 fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:                 %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               omp.wsloop {
+// CHECK:                 omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                   %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK:                   %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
+// CHECK:                   %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.single nowait {
+// CHECK:                 hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:                 fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.barrier
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_workshare_copy_llvm_ptr(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME:                                                %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @wsfunc(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_5:.*]] = arith.constant true
+// CHECK:           fir.if %[[VAL_5]] {
+// CHECK:             %[[VAL_6:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK:             %[[VAL_7:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_6]] -> @_workshare_copy_llvm_ptr : !llvm.ptr, %[[VAL_7]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               %[[VAL_8:.*]] = fir.alloca i32
+// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !fir.ref<i32> to !llvm.ptr
+// CHECK:               llvm.store %[[VAL_9]], %[[VAL_6]] : !llvm.ptr, !llvm.ptr
+// CHECK:               fir.store %[[VAL_3]] to %[[VAL_8]] : !fir.ref<i32>
+// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_12:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_12]] to %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
-// CHECK:             %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
-// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_14:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> !llvm.ptr
+// CHECK:             %[[VAL_15:.*]] = builtin.unrealized_conversion_cast %[[VAL_14]] : !llvm.ptr to !fir.ref<i32>
+// CHECK:             %[[VAL_16:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_16]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_18]](%[[VAL_16]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
-// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.loop_nest (%[[VAL_20:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_17]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_24:.*]] = arith.subi %[[VAL_22]], %[[VAL_23]] : i32
+// CHECK:                 %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_3]] : i32
+// CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_19]]#0 (%[[VAL_20]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_25]] to %[[VAL_26]] temporary_lhs : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.single nowait {
-// CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK:               "test.test1"(%[[VAL_15]]) : (!fir.ref<i32>) -> ()
+// CHECK:               hlfir.assign %[[VAL_19]]#0 to %[[VAL_17]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_19]]#0 : !fir.heap<!fir.array<42xi32>>
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.barrier
-// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
-  func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
-    omp.parallel {
-      omp.workshare {
-        %c42 = arith.constant 42 : index
-        %c1_i32 = arith.constant 1 : i32
-        %0 = fir.shape %c42 : (index) -> !fir.shape<1>
-        %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-        %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-        %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-        %true = arith.constant true
-        %c1 = arith.constant 1 : index
-        "omp.workshare_loop_wrapper"() ({
-          omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
-            %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-            %8 = fir.load %7 : !fir.ref<i32>
-            %9 = arith.subi %8, %c1_i32 : i32
-            %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-            hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
-            omp.yield
-          }
-          omp.terminator
-        }) : () -> ()
-        %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-        fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
-        omp.terminator
-      }
-      omp.terminator
-    }
-    return
-  }
-}
+

>From 70ac29218c0750598fe8255846e62774f609a629 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 15:50:58 +0900
Subject: [PATCH 16/23] Hoist allocas

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 10 ++-
 .../Transforms/OpenMP/lower-workshare.mlir    | 69 +++++++++----------
 2 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 30af2556cf4ca..d0cd235d3eb07 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -163,7 +163,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
   OpBuilder copyFuncBuilder(m.getBodyRegion());
   fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
 
-  // TODO need to copyprivate the alloca's
   auto mapReloadedValue =
       [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
           OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
@@ -202,10 +201,17 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     SmallVector<Value> copyPrivate;
 
     for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
-      singleBuilder.clone(op, singleMapping);
       if (isSafeToParallelize(&op)) {
+        singleBuilder.clone(op, singleMapping);
         parallelBuilder.clone(op, rootMapping);
+      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
+        auto hoisted =
+            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
+        rootMapping.map(&*alloca, &*hoisted);
+        rootMapping.map(alloca.getResult(), hoisted.getResult());
+        copyPrivate.push_back(hoisted);
       } else {
+        singleBuilder.clone(op, singleMapping);
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
         for (auto res : op.getResults()) {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 19123e71cacf6..b78cfd80e17ac 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -1,5 +1,7 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
+// checks:
+// nowait on final omp.single
 func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   omp.parallel {
     omp.workshare {
@@ -37,6 +39,8 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // -----
 
+// checks:
+// fir.alloca hoisted out and copyprivate'd
 func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   omp.workshare {
     %c1_i32 = arith.constant 1 : i32
@@ -73,7 +77,6 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   return
 }
 
-
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
@@ -130,9 +133,9 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           return
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_workshare_copy_llvm_ptr(
-// CHECK-SAME:                                                %[[VAL_0:.*]]: !llvm.ptr,
-// CHECK-SAME:                                                %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK-LABEL:   func.func private @_workshare_copy_i32(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: !fir.ref<i32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: !fir.ref<i32>) {
 // CHECK:           return
 // CHECK:         }
 
@@ -141,46 +144,40 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK:           %[[VAL_5:.*]] = arith.constant true
-// CHECK:           fir.if %[[VAL_5]] {
-// CHECK:             %[[VAL_6:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
-// CHECK:             %[[VAL_7:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:             omp.single copyprivate(%[[VAL_6]] -> @_workshare_copy_llvm_ptr : !llvm.ptr, %[[VAL_7]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:               %[[VAL_8:.*]] = fir.alloca i32
-// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !fir.ref<i32> to !llvm.ptr
-// CHECK:               llvm.store %[[VAL_9]], %[[VAL_6]] : !llvm.ptr, !llvm.ptr
-// CHECK:               fir.store %[[VAL_3]] to %[[VAL_8]] : !fir.ref<i32>
-// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_12:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               fir.store %[[VAL_12]] to %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_4:.*]] = arith.constant true
+// CHECK:           fir.if %[[VAL_4]] {
+// CHECK:             %[[VAL_5:.*]] = fir.alloca i32
+// CHECK:             %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
+// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_14:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> !llvm.ptr
-// CHECK:             %[[VAL_15:.*]] = builtin.unrealized_conversion_cast %[[VAL_14]] : !llvm.ptr to !fir.ref<i32>
-// CHECK:             %[[VAL_16:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:             %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_16]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:             %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_18]](%[[VAL_16]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_20:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_17]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_24:.*]] = arith.subi %[[VAL_22]], %[[VAL_23]] : i32
-// CHECK:                 %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_3]] : i32
-// CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_19]]#0 (%[[VAL_20]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_25]] to %[[VAL_26]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
+// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
+// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.single nowait {
-// CHECK:               "test.test1"(%[[VAL_15]]) : (!fir.ref<i32>) -> ()
-// CHECK:               hlfir.assign %[[VAL_19]]#0 to %[[VAL_17]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_19]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
+// CHECK:               hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.barrier

>From f9487fa88624d0aec129aadfe8c36a94bf48b3cd Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 15:52:23 +0900
Subject: [PATCH 17/23] More tests

---
 .../Transforms/OpenMP/lower-workshare2.mlir   | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare2.mlir

diff --git a/flang/test/Transforms/OpenMP/lower-workshare2.mlir b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
new file mode 100644
index 0000000000000..325a40d418445
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -0,0 +1,21 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @nonowait
+func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  // CHECK: omp.barrier
+  omp.workshare {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @nowait
+func.func @nowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  // CHECK-NOT: omp.barrier
+  omp.workshare nowait {
+    omp.terminator
+  }
+  return
+}

>From fb12bba97ac7d5bab63ee5a1a60d3bb8ca94af3b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:13:39 +0900
Subject: [PATCH 18/23] Emit body for copy func

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 44 +++++--------------
 .../Transforms/OpenMP/lower-workshare.mlir    |  6 +++
 2 files changed, 17 insertions(+), 33 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index d0cd235d3eb07..20f45296a8159 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -69,10 +69,6 @@ struct SingleRegion {
   Block::iterator begin, end;
 };
 
-static bool isSupportedByFirAlloca(Type ty) {
-  return !isa<fir::ReferenceType>(ty);
-}
-
 static bool mustParallelizeOp(Operation *op) {
   // TODO as in shouldUseWorkshareLowering we be careful not to pick up
   // workshare_loop_wrapper in nested omp.parallel ops
@@ -98,14 +94,10 @@ static bool isSafeToParallelize(Operation *op) {
 static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                          fir::FirOpBuilder builder) {
   mlir::ModuleOp module = builder.getModule();
-  std::string copyFuncName;
-  if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
-    mlir::Type eleTy = rt.getEleTy();
-    copyFuncName =
-        fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
-  } else {
-    copyFuncName = "_workshare_copy_llvm_ptr";
-  }
+  auto rt = cast<fir::ReferenceType>(varType);
+  mlir::Type eleTy = rt.getEleTy();
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
 
   if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
     return decl;
@@ -120,6 +112,10 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
   builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
                       {loc, loc});
   builder.setInsertionPointToStart(&funcOp.getRegion().back());
+
+  Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(0));
+  builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(1));
+
   builder.create<mlir::func::ReturnOp>(loc);
   return funcOp;
 }
@@ -168,28 +164,10 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
     if (auto reloaded = rootMapping.lookupOrNull(v))
       return nullptr;
-    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
-    Value alloc, reloaded;
-    if (isSupportedByFirAlloca(ty)) {
-      alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
-      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
-      reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
-    } else {
-      auto one = allocaBuilder.create<LLVM::ConstantOp>(
-          loc, allocaBuilder.getI32Type(), 1);
-      alloc =
-          allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
-      Value toStore = singleBuilder
-                          .create<UnrealizedConversionCastOp>(
-                              loc, llvmPtrTy, singleMapping.lookup(v))
-                          .getResult(0);
-      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
-      reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
-      reloaded =
-          parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
-              .getResult(0);
-    }
+    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
+    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
     rootMapping.map(v, reloaded);
     return alloc;
   };
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index b78cfd80e17ac..997bc8d79f9b3 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -80,6 +80,8 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
 // CHECK:           return
 // CHECK:         }
 
@@ -130,12 +132,16 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
 // CHECK:           return
 // CHECK:         }
 
 // CHECK-LABEL:   func.func private @_workshare_copy_i32(
 // CHECK-SAME:                                           %[[VAL_0:.*]]: !fir.ref<i32>,
 // CHECK-SAME:                                           %[[VAL_1:.*]]: !fir.ref<i32>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<i32>
 // CHECK:           return
 // CHECK:         }
 

>From 9e8301120c0667338a583612ba8399e04e7b5aa7 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:59:34 +0900
Subject: [PATCH 19/23] Test the tmp storing logic

---
 .../Transforms/OpenMP/lower-workshare.mlir    |  2 -
 .../Transforms/OpenMP/lower-workshare3.mlir   | 74 +++++++++++++++++++
 2 files changed, 74 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare3.mlir

diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 997bc8d79f9b3..063d3865065e0 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -36,7 +36,6 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   return
 }
 
-
 // -----
 
 // checks:
@@ -190,4 +189,3 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
-
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
new file mode 100644
index 0000000000000..84eded9450328
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -0,0 +1,74 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+
+// tests if the correct values are stored
+
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK-NOT: fir.alloca
+    omp.workshare {
+
+      %t1 = "test.test1"() : () -> i32
+      // CHECK: %[[T1:.*]] = "test.test1"
+      // CHECK: fir.store %[[T1]]
+      %t2 = "test.test2"() : () -> i32
+      // CHECK: %[[T2:.*]] = "test.test2"
+      // CHECK: fir.store %[[T2]]
+      %t3 = "test.test3"() : () -> i32
+      // CHECK: %[[T3:.*]] = "test.test3"
+      // CHECK-NOT: fir.store %[[T3]]
+      %t4 = "test.test4"() : () -> i32
+      // CHECK: %[[T4:.*]] = "test.test4"
+      // CHECK: fir.store %[[T4]]
+      %t5 = "test.test5"() : () -> i32
+      // CHECK: %[[T5:.*]] = "test.test5"
+      // CHECK: fir.store %[[T5]]
+      %t6 = "test.test6"() : () -> i32
+      // CHECK: %[[T6:.*]] = "test.test6"
+      // CHECK-NOT: fir.store %[[T6]]
+
+
+      "test.test1"(%t1) : (i32) -> ()
+      "test.test1"(%t2) : (i32) -> ()
+      "test.test1"(%t3) : (i32) -> ()
+
+      %true = arith.constant true
+      fir.if %true {
+        "test.test2"(%t3) : (i32) -> ()
+      }
+
+      %c1_i32 = arith.constant 1 : i32
+
+      %t5_pure_use = arith.addi %t5, %c1_i32 : i32
+
+      %t6_mem_effect_use = "test.test8"(%t6) : (i32) -> i32
+      // CHECK: %[[T6_USE:.*]] = "test.test8"
+      // CHECK: fir.store %[[T6_USE]]
+
+      %c42 = arith.constant 42 : index
+      %c1 = arith.constant 1 : index
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          "test.test10"(%t1) : (i32) -> ()
+          "test.test10"(%t5_pure_use) : (i32) -> ()
+          "test.test10"(%t6_mem_effect_use) : (i32) -> ()
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+
+      "test.test10"(%t2) : (i32) -> ()
+      fir.if %true {
+        "test.test10"(%t4) : (i32) -> ()
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}

>From 6508a453e72e4f7563aae565045158bf42f31fba Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 17:33:52 +0900
Subject: [PATCH 20/23] Add workshare loop wrapper lowerings

---
 .../lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp  |  6 ++++--
 .../HLFIR/Transforms/OptimizedBufferization.cpp        | 10 +++++++---
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index b608677c52631..1848dbe2c7a2c 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -26,12 +26,13 @@
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace hlfir {
@@ -792,7 +793,8 @@ struct ElementalOpConversion
     // Generate a loop nest looping around the fir.elemental shape and clone
     // fir.elemental region inside the inner loop.
     hlfir::LoopNest loopNest =
-        hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
+        hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
+                           flangomp::shouldUseWorkshareLowering(elemental));
     auto insPt = builder.saveInsertionPoint();
     builder.setInsertionPointToStart(loopNest.body);
     auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c4aed6b79df92..150e3e9119724 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -20,6 +20,7 @@
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
 #include "flang/Optimizer/Transforms/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Dominance.h"
@@ -482,7 +483,8 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
   // Generate a loop nest looping around the hlfir.elemental shape and clone
   // hlfir.elemental region inside the inner loop
   hlfir::LoopNest loopNest =
-      hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
+      hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
+                         flangomp::shouldUseWorkshareLowering(elemental));
   builder.setInsertionPointToStart(loopNest.body);
   auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                         loopNest.oneBasedIndices);
@@ -553,7 +555,8 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
   llvm::SmallVector<mlir::Value> extents =
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
-      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
+      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
+                         flangomp::shouldUseWorkshareLowering(assign));
   builder.setInsertionPointToStart(loopNest.body);
   auto arrayElement =
       hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
@@ -648,7 +651,8 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
   llvm::SmallVector<mlir::Value> extents =
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
-      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
+      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
+                         flangomp::shouldUseWorkshareLowering(assign));
   builder.setInsertionPointToStart(loopNest.body);
   auto rhsArrayElement =
       hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);

>From 7931ed5b294e69071925c56f85ba6cbf39c45631 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 20:24:36 +0900
Subject: [PATCH 21/23] Clean up trivially dead ops

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 32 ++++-------
 .../Transforms/OpenMP/lower-workshare3.mlir   |  2 +-
 .../Transforms/OpenMP/lower-workshare4.mlir   | 55 +++++++++++++++++++
 3 files changed, 68 insertions(+), 21 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare4.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 20f45296a8159..a147db2cb5d59 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -152,6 +152,14 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
   return false;
 }
 
+/// We clone pure operations in both the parallel and single blocks. This
+/// function cleans them up if they end up with no uses.
+static void cleanupBlock(Block *block) {
+  for (Operation &op : llvm::make_early_inc_range(*block))
+    if (isOpTriviallyDead(&op))
+      op.erase();
+}
+
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
   OpBuilder rootBuilder(sourceRegion.getContext());
@@ -258,13 +266,8 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         singleOperands.copyprivateVars =
             moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                          singleBuilder, parallelBuilder);
+        cleanupBlock(singleBlock);
         for (auto var : singleOperands.copyprivateVars) {
-          Type ty;
-          if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
-            ty = firAlloca.getAllocatedType();
-          } else {
-            ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
-          }
           mlir::func::FuncOp funcOp =
               createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
           singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
@@ -302,6 +305,9 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 
     rootBuilder.clone(*block.getTerminator(), rootMapping);
   }
+
+  for (Block &targetBlock : targetRegion)
+    cleanupBlock(&targetBlock);
 }
 
 /// Lowers workshare to a sequence of single-thread regions and parallel loops
@@ -372,20 +378,6 @@ class LowerWorksharePass
 
       lowerWorkshare(wsOp);
     });
-
-    // Do folding
-    for (Operation *isolatedParent : parents) {
-      RewritePatternSet patterns(&getContext());
-      GreedyRewriteConfig config;
-      // prevent the pattern driver form merging blocks
-      config.enableRegionSimplification =
-          mlir::GreedySimplifyRegionLevel::Disabled;
-      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
-                                              std::move(patterns), config))) {
-        emitError(isolatedParent->getLoc(), "error in lower workshare\n");
-        signalPassFailure();
-      }
-    }
   }
 };
 } // namespace
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index 84eded9450328..aee95a464a31b 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -3,7 +3,7 @@
 
 // tests if the correct values are stored
 
-func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+func.func @wsfunc() {
   omp.parallel {
   // CHECK: fir.alloca
   // CHECK: fir.alloca
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
new file mode 100644
index 0000000000000..6cff0075b4fe5
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -0,0 +1,55 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+      %t1 = "test.test1"() : () -> i32
+
+      %c1 = arith.constant 1 : index
+      %c42 = arith.constant 42 : index
+
+      %c2 = arith.constant 2 : index
+      "test.test3"(%c2) : (index) -> ()
+
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          "test.test2"() : () -> ()
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @wsfunc() {
+// CHECK:           %[[VAL_0:.*]] = fir.alloca i32
+// CHECK:           omp.parallel {
+// CHECK:             %[[VAL_1:.*]] = arith.constant true
+// CHECK:             fir.if %[[VAL_1]] {
+// CHECK:               omp.single {
+// CHECK:                 %[[VAL_2:.*]] = "test.test1"() : () -> i32
+// CHECK:                 %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:                 "test.test3"(%[[VAL_3]]) : (index) -> ()
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_5:.*]] = arith.constant 42 : index
+// CHECK:               omp.wsloop nowait {
+// CHECK:                 omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) {
+// CHECK:                   "test.test2"() : () -> ()
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.barrier
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+

>From 828db377f74059a02b1861da48d5f7b9e47779c4 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 21:55:14 +0900
Subject: [PATCH 22/23] Only handle single-block regions for now

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  80 +++++----
 .../Transforms/OpenMP/lower-workshare.mlir    | 154 +++++++++---------
 .../Transforms/OpenMP/lower-workshare4.mlir   |  31 ++--
 3 files changed, 143 insertions(+), 122 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index a147db2cb5d59..5998489c13d38 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -13,12 +13,14 @@
 #include <flang/Optimizer/Dialect/FIRType.h>
 #include <flang/Optimizer/HLFIR/HLFIROps.h>
 #include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
 #include <llvm/ADT/STLExtras.h>
 #include <llvm/ADT/SmallVectorExtras.h>
 #include <llvm/ADT/iterator_range.h>
 #include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
 #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
 #include <mlir/IR/BuiltinOps.h>
@@ -161,7 +163,8 @@ static void cleanupBlock(Block *block) {
 }
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
-                              IRMapping &rootMapping, Location loc) {
+                              IRMapping &rootMapping, Location loc,
+                              mlir::DominanceInfo &di) {
   OpBuilder rootBuilder(sourceRegion.getContext());
   ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
   OpBuilder copyFuncBuilder(m.getBodyRegion());
@@ -214,14 +217,19 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     return copyPrivate;
   };
 
-  // TODO Need to handle these (clone them) in dominator tree order
   for (Block &block : sourceRegion) {
-    rootBuilder.createBlock(
+    Block *targetBlock = rootBuilder.createBlock(
         &targetRegion, {}, block.getArgumentTypes(),
         llvm::map_to_vector(block.getArguments(),
                             [](BlockArgument arg) { return arg.getLoc(); }));
-    Operation *terminator = block.getTerminator();
+    rootMapping.map(&block, targetBlock);
+    rootMapping.map(block.getArguments(), targetBlock->getArguments());
+  }
 
+  auto handleOneBlock = [&](Block &block) {
+    Block &targetBlock = *rootMapping.lookup(&block);
+    rootBuilder.setInsertionPointToStart(&targetBlock);
+    Operation *terminator = block.getTerminator();
     SmallVector<std::variant<SingleRegion, Operation *>> regions;
 
     auto it = block.begin();
@@ -298,12 +306,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
           for (auto [region, clonedRegion] :
                llvm::zip(op->getRegions(), cloned->getRegions()))
-            parallelizeRegion(region, clonedRegion, rootMapping, loc);
+            parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
         }
       }
     }
 
     rootBuilder.clone(*block.getTerminator(), rootMapping);
+  };
+
+  if (sourceRegion.hasOneBlock()) {
+    handleOneBlock(sourceRegion.front());
+  } else {
+    auto &domTree = di.getDomTree(&sourceRegion);
+    for (auto node : llvm::breadth_first(domTree.getRootNode())) {
+      handleOneBlock(*node->getBlock());
+    }
   }
 
   for (Block &targetBlock : targetRegion)
@@ -336,47 +353,46 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 ///
 /// Note that we allocate temporary memory for values in omp.single's which need
 /// to be accessed in all threads in the closest omp.parallel
-void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
   Location loc = wsOp->getLoc();
   IRMapping rootMapping;
 
   OpBuilder rootBuilder(wsOp);
 
-  // TODO We need something like an scf;execute here, but that is not registered
-  // so using fir.if for now but it looks like it does not support multiple
-  // blocks so it doesnt work for multi block case...
-  auto ifOp = rootBuilder.create<fir::IfOp>(
-      loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
-  ifOp.getThenRegion().front().erase();
-
-  parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
-
-  Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
-  assert(isa<omp::TerminatorOp>(terminatorOp));
-  OpBuilder termBuilder(terminatorOp);
-
+  // TODO We need something like an scf.execute here, but that is not registered
+  // so using omp.workshare as a placeholder. We need this op because our
+  // parallelizeRegion works on regions and not blocks.
+  omp::WorkshareOp newOp =
+      rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
   if (!wsOp.getNowait())
-    termBuilder.create<omp::BarrierOp>(loc);
-
-  termBuilder.create<fir::ResultOp>(loc, ValueRange());
-
-  terminatorOp->erase();
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc, di);
+
+  if (wsOp.getRegion().getBlocks().size() != 1)
+    return failure();
+
+  // Inline the contents of the placeholder workshare op into its parent block.
+  Block *theBlock = &newOp.getRegion().front();
+  Operation *term = theBlock->getTerminator();
+  Block *parentBlock = wsOp->getBlock();
+  parentBlock->getOperations().splice(newOp->getIterator(),
+                                      theBlock->getOperations());
+  assert(term->getNumOperands() == 0);
+  term->erase();
+  newOp->erase();
   wsOp->erase();
-
-  return;
+  return success();
 }
 
 class LowerWorksharePass
     : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
 public:
   void runOnOperation() override {
-    SmallPtrSet<Operation *, 8> parents;
+    mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
     getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
-      Operation *isolatedParent =
-          wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
-      parents.insert(isolatedParent);
-
-      lowerWorkshare(wsOp);
+      if (failed(lowerWorkshare(wsOp, di)))
+        signalPassFailure();
     });
   }
 };
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 063d3865065e0..b31e951223d56 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -86,43 +86,46 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant true
 // CHECK:           omp.parallel {
-// CHECK:             fir.if %[[VAL_4]] {
-// CHECK:               %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:               omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:                 %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:                 %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:                 %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:                 fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:                 %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               omp.wsloop {
-// CHECK:                 omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                   %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                   %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
-// CHECK:                   %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
-// CHECK:                   %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                   hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
-// CHECK:                   omp.yield
-// CHECK:                 }
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               omp.single nowait {
-// CHECK:                 hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:                 fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
-// CHECK:                 omp.terminator
+// CHECK:             %[[VAL_1:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK:               %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_5]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_7:.*]] = arith.constant 42 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : i32
+// CHECK:             %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_9]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_11]](%[[VAL_9]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_13:.*]] = arith.constant true
+// CHECK:             %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_14]]) to (%[[VAL_7]]) inclusive step (%[[VAL_14]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_8]] : i32
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                 omp.yield
 // CHECK:               }
-// CHECK:               omp.barrier
+// CHECK:               omp.terminator
 // CHECK:             }
+// CHECK:             omp.single nowait {
+// CHECK:               %[[VAL_20:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               %[[VAL_21:.*]] = fir.insert_value %[[VAL_20]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               hlfir.assign %[[VAL_12]]#0 to %[[VAL_10]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_22:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             %[[VAL_23:.*]] = fir.insert_value %[[VAL_22]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             omp.barrier
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
@@ -146,46 +149,51 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_4:.*]] = arith.constant true
-// CHECK:           fir.if %[[VAL_4]] {
-// CHECK:             %[[VAL_5:.*]] = fir.alloca i32
-// CHECK:             %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:             omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:               fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
-// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               omp.terminator
-// CHECK:             }
-// CHECK:             %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:             %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
-// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
-// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
-// CHECK:                 omp.yield
-// CHECK:               }
-// CHECK:               omp.terminator
-// CHECK:             }
-// CHECK:             omp.single nowait {
-// CHECK:               "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
-// CHECK:               hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
-// CHECK:               omp.terminator
+// CHECK:           %[[VAL_1:.*]] = fir.alloca i32
+// CHECK:           %[[VAL_2:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_2]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_1]] : !fir.ref<i32>
+// CHECK:             %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK:             %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_7:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:             fir.store %[[VAL_7]] to %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_7]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           %[[VAL_9:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_10:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_11:.*]] = fir.shape %[[VAL_10]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_15:.*]] = arith.constant true
+// CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:           omp.wsloop {
+// CHECK:             omp.loop_nest (%[[VAL_17:.*]]) : index = (%[[VAL_16]]) to (%[[VAL_10]]) inclusive step (%[[VAL_16]]) {
+// CHECK:               %[[VAL_18:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_17]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:               %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
+// CHECK:               %[[VAL_20:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
+// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_19]], %[[VAL_20]] : i32
+// CHECK:               %[[VAL_22:.*]] = arith.subi %[[VAL_21]], %[[VAL_9]] : i32
+// CHECK:               %[[VAL_23:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_17]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:               hlfir.assign %[[VAL_22]] to %[[VAL_23]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.yield
 // CHECK:             }
-// CHECK:             omp.barrier
+// CHECK:             omp.terminator
 // CHECK:           }
+// CHECK:           omp.single nowait {
+// CHECK:             %[[VAL_24:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             %[[VAL_25:.*]] = fir.insert_value %[[VAL_24]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             "test.test1"(%[[VAL_1]]) : (!fir.ref<i32>) -> ()
+// CHECK:             hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:             fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           %[[VAL_26:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:           %[[VAL_27:.*]] = fir.insert_value %[[VAL_26]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:           omp.barrier
 // CHECK:           return
 // CHECK:         }
+
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index 6cff0075b4fe5..d695a1c354517 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -29,25 +29,22 @@ func.func @wsfunc() {
 // CHECK-LABEL:   func.func @wsfunc() {
 // CHECK:           %[[VAL_0:.*]] = fir.alloca i32
 // CHECK:           omp.parallel {
-// CHECK:             %[[VAL_1:.*]] = arith.constant true
-// CHECK:             fir.if %[[VAL_1]] {
-// CHECK:               omp.single {
-// CHECK:                 %[[VAL_2:.*]] = "test.test1"() : () -> i32
-// CHECK:                 %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK:                 "test.test3"(%[[VAL_3]]) : (index) -> ()
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:               %[[VAL_5:.*]] = arith.constant 42 : index
-// CHECK:               omp.wsloop nowait {
-// CHECK:                 omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) {
-// CHECK:                   "test.test2"() : () -> ()
-// CHECK:                   omp.yield
-// CHECK:                 }
-// CHECK:                 omp.terminator
+// CHECK:             omp.single {
+// CHECK:               %[[VAL_1:.*]] = "test.test1"() : () -> i32
+// CHECK:               %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:               "test.test3"(%[[VAL_2]]) : (index) -> ()
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK:             omp.wsloop nowait {
+// CHECK:               omp.loop_nest (%[[VAL_5:.*]]) : index = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_3]]) {
+// CHECK:                 "test.test2"() : () -> ()
+// CHECK:                 omp.yield
 // CHECK:               }
-// CHECK:               omp.barrier
+// CHECK:               omp.terminator
 // CHECK:             }
+// CHECK:             omp.barrier
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return

>From f8834c9285df0ea2a43aae3e9c69e5694fb37345 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 6 Aug 2024 13:52:20 +0900
Subject: [PATCH 23/23] Fix tests for custom assembly for loop wrapper

---
 flang/test/HLFIR/bufferize-workshare.fir           | 4 ++--
 flang/test/Transforms/OpenMP/lower-workshare.mlir  | 8 ++++----
 flang/test/Transforms/OpenMP/lower-workshare3.mlir | 4 ++--
 flang/test/Transforms/OpenMP/lower-workshare4.mlir | 4 ++--
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/flang/test/HLFIR/bufferize-workshare.fir b/flang/test/HLFIR/bufferize-workshare.fir
index 33b368a62eaab..02fe6b1d1799c 100644
--- a/flang/test/HLFIR/bufferize-workshare.fir
+++ b/flang/test/HLFIR/bufferize-workshare.fir
@@ -12,7 +12,7 @@
 // CHECK:               %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               %[[VAL_7:.*]] = arith.constant true
 // CHECK:               %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK:               "omp.workshare_loop_wrapper"() ({
+// CHECK:               omp.workshare_loop_wrapper {
 // CHECK:                 omp.loop_nest (%[[VAL_9:.*]]) : index = (%[[VAL_8]]) to (%[[VAL_1]]) inclusive step (%[[VAL_8]]) {
 // CHECK:                   %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
 // CHECK:                   %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
@@ -22,7 +22,7 @@
 // CHECK:                   omp.yield
 // CHECK:                 }
 // CHECK:                 omp.terminator
-// CHECK:               }) : () -> ()
+// CHECK:               }
 // CHECK:               %[[VAL_14:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
 // CHECK:               %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_7]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
 // CHECK:               %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_6]]#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index b31e951223d56..9347863dc4a60 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -13,7 +13,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
       %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
       %true = arith.constant true
       %c1 = arith.constant 1 : index
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
           %8 = fir.load %7 : !fir.ref<i32>
@@ -23,7 +23,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
       %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
       %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
       %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
@@ -52,7 +52,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
     %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
     %true = arith.constant true
     %c1 = arith.constant 1 : index
-    "omp.workshare_loop_wrapper"() ({
+    omp.workshare_loop_wrapper {
       omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
         %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
         %8 = fir.load %7 : !fir.ref<i32>
@@ -64,7 +64,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
         omp.yield
       }
       omp.terminator
-    }) : () -> ()
+    }
     %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
     %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
     %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index aee95a464a31b..afb41d95e7198 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -52,7 +52,7 @@ func.func @wsfunc() {
 
       %c42 = arith.constant 42 : index
       %c1 = arith.constant 1 : index
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           "test.test10"(%t1) : (i32) -> ()
           "test.test10"(%t5_pure_use) : (i32) -> ()
@@ -60,7 +60,7 @@ func.func @wsfunc() {
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
 
       "test.test10"(%t2) : (i32) -> ()
       fir.if %true {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index d695a1c354517..0a70007a9e78d 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -12,13 +12,13 @@ func.func @wsfunc() {
       %c2 = arith.constant 2 : index
       "test.test3"(%c2) : (index) -> ()
 
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           "test.test2"() : () -> ()
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
       omp.terminator
     }
     omp.terminator


