[llvm-branch-commits] [flang] [mlir] [flang] Lower omp.workshare to other omp constructs (PR #101446)

Ivan R. Ivanov via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Aug 19 01:41:04 PDT 2024


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/101446

>From ef896d238882550e8dac3d26e5628625c84044c6 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 17:40:13 +0900
Subject: [PATCH 01/36] Fix typo

---
 flang/include/flang/Optimizer/OpenMP/Passes.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 7c203efbb5ee1..2c7d8df1a3290 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -1,4 +1,4 @@
-//===-- Passes.td - HLFIR pass definition file -------------*- tablegen -*-===//
+//===-- Passes.td - flang OpenMP pass definition -----------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

>From 78f451caa0e0df33e4bae00b6b4771ba8fbc3e18 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 12:33:49 +0900
Subject: [PATCH 02/36] Rename pass names

---
 flang/docs/OpenMP-declare-target.md                    |  4 ++--
 flang/docs/OpenMP-descriptor-management.md             |  4 ++--
 flang/include/flang/Optimizer/OpenMP/Passes.td         |  6 +++---
 flang/include/flang/Tools/CLOptions.inc                |  8 ++++----
 flang/lib/Optimizer/OpenMP/CMakeLists.txt              |  6 +++---
 ...{OMPFunctionFiltering.cpp => FunctionFiltering.cpp} |  9 ++++-----
 ...MapInfoFinalization.cpp => MapInfoFinalization.cpp} | 10 +++++-----
 ...{OMPMarkDeclareTarget.cpp => MarkDeclareTarget.cpp} |  7 +++----
 8 files changed, 26 insertions(+), 28 deletions(-)
 rename flang/lib/Optimizer/OpenMP/{OMPFunctionFiltering.cpp => FunctionFiltering.cpp} (94%)
 rename flang/lib/Optimizer/OpenMP/{OMPMapInfoFinalization.cpp => MapInfoFinalization.cpp} (97%)
 rename flang/lib/Optimizer/OpenMP/{OMPMarkDeclareTarget.cpp => MarkDeclareTarget.cpp} (95%)
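
Note: the rename touches only the C++/TableGen identifiers. The pass
arguments (`omp-map-info-finalization`, `omp-mark-declare-target`,
`omp-function-filtering`) and their order in `createOpenMPFIRPassPipeline`
stay the same. A minimal standalone sketch of that ordering follows. It is
not the flang code: `openMPFIRPipeline` is a hypothetical helper that returns
the pass arguments as strings instead of populating an `mlir::PassManager`.

```cpp
#include <string>
#include <vector>

// Hypothetical model of createOpenMPFIRPassPipeline: returns the pass
// arguments in the order the real pipeline adds the passes.
std::vector<std::string> openMPFIRPipeline(bool isTargetDevice) {
  std::vector<std::string> passes = {
      "omp-map-info-finalization", // nested on all top-level operations
      "omp-mark-declare-target",   // runs on the module
  };
  // Host-only functions are filtered out only when compiling for the device.
  if (isTargetDevice)
    passes.push_back("omp-function-filtering");
  return passes;
}
```

Under this model, `openMPFIRPipeline(false)` yields two entries and
`openMPFIRPipeline(true)` appends the filtering pass as a third.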

diff --git a/flang/docs/OpenMP-declare-target.md b/flang/docs/OpenMP-declare-target.md
index d29a46807e1ea..45062469007b6 100644
--- a/flang/docs/OpenMP-declare-target.md
+++ b/flang/docs/OpenMP-declare-target.md
@@ -149,7 +149,7 @@ flang/lib/Lower/OpenMP.cpp function `genDeclareTargetIntGlobal`.
 
 There are currently two passes within Flang that are related to the processing 
 of `declare target`:
-* `OMPMarkDeclareTarget` - This pass is in charge of marking functions captured
+* `MarkDeclareTarget` - This pass is in charge of marking functions captured
 (called from) in `target` regions or other `declare target` marked functions as
 `declare target`. It does so recursively, i.e. nested calls will also be 
 implicitly marked. It currently will try to mark things as conservatively as 
@@ -157,7 +157,7 @@ possible, e.g. if captured in a `target` region it will apply `nohost`, unless
 it encounters a `host` `declare target` in which case it will apply the `any` 
 device type. Functions are handled similarly, except we utilise the parent's 
 device type where possible.
-* `OMPFunctionFiltering` - This is executed after the `OMPMarkDeclareTarget`
+* `FunctionFiltering` - This is executed after the `MarkDeclareTarget`
 pass, and its job is to conservatively remove host functions from
 the module where possible when compiling for the device. This helps make 
 sure that most incompatible code for the host is not lowered for the 
diff --git a/flang/docs/OpenMP-descriptor-management.md b/flang/docs/OpenMP-descriptor-management.md
index d0eb01b00f9bb..cdc72f3cac3a4 100644
--- a/flang/docs/OpenMP-descriptor-management.md
+++ b/flang/docs/OpenMP-descriptor-management.md
@@ -44,7 +44,7 @@ Currently, Flang will lower these descriptor types in the OpenMP lowering (lower
 to all other map types, generating an omp.MapInfoOp containing relevant information required for lowering
 the OpenMP dialect to LLVM-IR during the final stages of the MLIR lowering. However, after 
 the lowering to FIR/HLFIR has been performed an OpenMP dialect specific pass for Fortran, 
-`OMPMapInfoFinalizationPass` (Optimizer/OMPMapInfoFinalization.cpp) will expand the 
+`MapInfoFinalizationPass` (Optimizer/MapInfoFinalization.cpp) will expand the 
 `omp.MapInfoOp`'s containing descriptors (which currently will be a `BoxType` or `BoxAddrOp`) into multiple 
 mappings, with one extra per pointer member in the descriptor that is supported on top of the original
 descriptor map operation. These pointers members are linked to the parent descriptor by adding them to 
@@ -53,7 +53,7 @@ owning operation's (`omp.TargetOp`, `omp.TargetDataOp` etc.) map operand list an
 operation is `IsolatedFromAbove`, it also inserts them as `BlockArgs` to canonicalize the mappings and
 simplify lowering.
 
-An example transformation by the `OMPMapInfoFinalizationPass`:
+An example transformation by the `MapInfoFinalizationPass`:
 
 ```
 
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 2c7d8df1a3290..395178e26a576 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -11,7 +11,7 @@
 
 include "mlir/Pass/PassBase.td"
 
-def OMPMapInfoFinalizationPass
+def MapInfoFinalizationPass
     : Pass<"omp-map-info-finalization"> {
   let summary = "expands OpenMP MapInfo operations containing descriptors";
   let description = [{
@@ -22,13 +22,13 @@ def OMPMapInfoFinalizationPass
   let dependentDialects = ["mlir::omp::OpenMPDialect"];
 }
 
-def OMPMarkDeclareTargetPass
+def MarkDeclareTargetPass
     : Pass<"omp-mark-declare-target", "mlir::ModuleOp"> {
   let summary = "Marks all functions called by an OpenMP declare target function as declare target";
   let dependentDialects = ["mlir::omp::OpenMPDialect"];
 }
 
-def OMPFunctionFiltering : Pass<"omp-function-filtering"> {
+def FunctionFiltering : Pass<"omp-function-filtering"> {
   let summary = "Filters out functions intended for the host when compiling "
                 "for the target device.";
   let dependentDialects = [
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 1ad74a98c8d95..05b2f31711add 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -17,8 +17,8 @@
 #include "mlir/Transforms/Passes.h"
 #include "flang/Optimizer/CodeGen/CodeGen.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
-#include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "flang/Optimizer/Transforms/Passes.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Support/CommandLine.h"
 #include <type_traits>
@@ -359,10 +359,10 @@ inline void createHLFIRToFIRPassPipeline(
 inline void createOpenMPFIRPassPipeline(
     mlir::PassManager &pm, bool isTargetDevice) {
   addNestedPassToAllTopLevelOperations(
-      pm, flangomp::createOMPMapInfoFinalizationPass);
-  pm.addPass(flangomp::createOMPMarkDeclareTargetPass());
+      pm, flangomp::createMapInfoFinalizationPass);
+  pm.addPass(flangomp::createMarkDeclareTargetPass());
   if (isTargetDevice)
-    pm.addPass(flangomp::createOMPFunctionFiltering());
+    pm.addPass(flangomp::createFunctionFiltering());
 }
 
 #if !defined(FLANG_EXCLUDE_CODEGEN)
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index c31b490c87f67..a8984d256b8f6 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -1,9 +1,9 @@
 get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 
 add_flang_library(FlangOpenMPTransforms
-  OMPFunctionFiltering.cpp
-  OMPMapInfoFinalization.cpp
-  OMPMarkDeclareTarget.cpp
+  FunctionFiltering.cpp
+  MapInfoFinalization.cpp
+  MarkDeclareTarget.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
similarity index 94%
rename from flang/lib/Optimizer/OpenMP/OMPFunctionFiltering.cpp
rename to flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 2011bd56352ff..b889b3a137841 100644
--- a/flang/lib/Optimizer/OpenMP/OMPFunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -1,4 +1,4 @@
-//===- OMPFunctionFiltering.cpp -------------------------------------------===//
+//===- FunctionFiltering.cpp ----------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -29,11 +29,10 @@ namespace flangomp {
 using namespace mlir;
 
 namespace {
-class OMPFunctionFilteringPass
-    : public flangomp::impl::OMPFunctionFilteringBase<
-          OMPFunctionFilteringPass> {
+class FunctionFilteringPass
+    : public flangomp::impl::FunctionFilteringBase<FunctionFilteringPass> {
 public:
-  OMPFunctionFilteringPass() = default;
+  FunctionFilteringPass() = default;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
diff --git a/flang/lib/Optimizer/OpenMP/OMPMapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
similarity index 97%
rename from flang/lib/Optimizer/OpenMP/OMPMapInfoFinalization.cpp
rename to flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 84fd8b28fb6a0..be115be29fdd2 100644
--- a/flang/lib/Optimizer/OpenMP/OMPMapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -1,4 +1,4 @@
-//===- OMPMapInfoFinalization.cpp -----------------------------------------===//
+//===- MapInfoFinalization.cpp --------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -46,9 +46,9 @@ namespace flangomp {
 } // namespace flangomp
 
 namespace {
-class OMPMapInfoFinalizationPass
-    : public flangomp::impl::OMPMapInfoFinalizationPassBase<
-          OMPMapInfoFinalizationPass> {
+class MapInfoFinalizationPass
+    : public flangomp::impl::MapInfoFinalizationPassBase<
+          MapInfoFinalizationPass> {
 
   void genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
                                fir::FirOpBuilder &builder,
@@ -244,7 +244,7 @@ class OMPMapInfoFinalizationPass
       // all users appropriately, making sure to only add a single member link
       // per new generation for the original originating descriptor MapInfoOp.
       assert(llvm::hasSingleElement(op->getUsers()) &&
-             "OMPMapInfoFinalization currently only supports single users "
+             "MapInfoFinalization currently only supports single users "
              "of a MapInfoOp");
 
       if (!op.getMembers().empty()) {
diff --git a/flang/lib/Optimizer/OpenMP/OMPMarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
similarity index 95%
rename from flang/lib/Optimizer/OpenMP/OMPMarkDeclareTarget.cpp
rename to flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
index b36c2af91bfe3..5feeba230ef97 100644
--- a/flang/lib/Optimizer/OpenMP/OMPMarkDeclareTarget.cpp
+++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
@@ -1,4 +1,4 @@
-//===- OMPMarkDeclareTarget.cpp -------------------------------------------===//
+//===- MarkDeclareTarget.cpp ----------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -28,9 +28,8 @@ namespace flangomp {
 } // namespace flangomp
 
 namespace {
-class OMPMarkDeclareTargetPass
-    : public flangomp::impl::OMPMarkDeclareTargetPassBase<
-          OMPMarkDeclareTargetPass> {
+class MarkDeclareTargetPass
+    : public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
 
   void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
                        mlir::omp::DeclareTargetCaptureClause parentCapClause,

>From 70652228a4851f2b4cf03cabe76da5dd96dc0bb9 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 12:42:09 +0900
Subject: [PATCH 03/36] Fix defines

---
 flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp   | 2 +-
 flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp | 2 +-
 flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp   | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)
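
Note: the defines have to follow the rename because mlir-tblgen derives each
guard macro from the TableGen `def` name: `GEN_PASS_DEF_` followed by the
name in upper case. A minimal sketch of that mapping, assuming a hypothetical
helper `passDefMacro` (it is not part of MLIR):

```cpp
#include <cctype>
#include <string>

// Hypothetical helper: builds the GEN_PASS_DEF_* guard for a pass `def` name
// by upper-casing the name and prepending the fixed prefix.
std::string passDefMacro(const std::string &defName) {
  std::string macro = "GEN_PASS_DEF_";
  for (unsigned char c : defName)
    macro += static_cast<char>(std::toupper(c));
  return macro;
}
```

For example, `passDefMacro("FunctionFiltering")` gives
`GEN_PASS_DEF_FUNCTIONFILTERING`, which matches the define in the diff below.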

diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index b889b3a137841..bd9005d3e2df6 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -22,7 +22,7 @@
 #include "llvm/ADT/SmallVector.h"
 
 namespace flangomp {
-#define GEN_PASS_DEF_OMPFUNCTIONFILTERING
+#define GEN_PASS_DEF_FUNCTIONFILTERING
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 } // namespace flangomp
 
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index be115be29fdd2..6e9cd03dca8f3 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -41,7 +41,7 @@
 #include <iterator>
 
 namespace flangomp {
-#define GEN_PASS_DEF_OMPMAPINFOFINALIZATIONPASS
+#define GEN_PASS_DEF_MAPINFOFINALIZATIONPASS
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 } // namespace flangomp
 
diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
index 5feeba230ef97..a7ffd5fda82b7 100644
--- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
+++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
@@ -23,7 +23,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 
 namespace flangomp {
-#define GEN_PASS_DEF_OMPMARKDECLARETARGETPASS
+#define GEN_PASS_DEF_MARKDECLARETARGETPASS
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 } // namespace flangomp
 

>From 604b0293e0574e9d697d4071c2b853a5a27af1e1 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:09:09 +0900
Subject: [PATCH 04/36] [MLIR][omp] Add omp.workshare op

---
 .../Dialect/OpenMP/OpenMPClauseOperands.h     |  3 +++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 22 +++++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 13 +++++++++++
 3 files changed, 38 insertions(+)
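
Note: `WorkshareOperands` follows the pattern of the other operand structures
in OpenMPClauseOperands.h, where a `detail::Clauses` alias combines one struct
per clause. A simplified standalone sketch of the idea, assuming a plain
`bool nowait` member in place of the real MLIR attribute type; `wantsNowait`
is a hypothetical consumer added only for illustration:

```cpp
// Simplified stand-in for the real clause struct (assumption: a plain bool
// replaces the attribute type used in OpenMPClauseOperands.h).
struct NowaitClauseOps {
  bool nowait = false;
};

namespace detail {
// Each clause struct is mixed in as a base class, so its members become
// members of the combined operands structure.
template <typename... Mixins>
struct Clauses : public Mixins... {};
} // namespace detail

using WorkshareOperands = detail::Clauses<NowaitClauseOps>;

// Hypothetical consumer, mirroring how the builder reads `clauses.nowait`.
bool wantsNowait(const WorkshareOperands &clauses) { return clauses.nowait; }
```

Adding another clause to the operation would then only mean listing another
struct in the alias; existing consumers keep reading the members they know.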

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 38e4d8f245e4f..d14e5e17afbb0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -17,6 +17,7 @@
 
 #include "mlir/IR/BuiltinAttributes.h"
 #include "llvm/ADT/SmallVector.h"
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
 
@@ -316,6 +317,8 @@ using TeamsOperands =
     detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
                     PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
 
+using WorkshareOperands = detail::Clauses<NowaitClauseOps>;
+
 using WsloopOperands =
     detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
                     OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 68f92e6952694..5199ff50abb95 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -286,6 +286,28 @@ def SingleOp : OpenMP_Op<"single", traits = [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// 2.8.3 Workshare Construct
+//===----------------------------------------------------------------------===//
+
+def WorkshareOp : OpenMP_Op<"workshare", clauses = [
+    OpenMP_NowaitClause,
+  ], singleRegion = true> {
+  let summary = "workshare directive";
+  let description = [{
+    The workshare construct divides the execution of the enclosed structured
+    block into separate units of work, and causes the threads of the team to
+    share the work such that each unit is executed only once by one thread, in
+    the context of its implicit task
+  }] # clausesDescription;
+
+  let builders = [
+    OpBuilder<(ins CArg<"const WorkshareOperands &">:$clauses)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Loop Nest
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 11780f84697b1..9a189eb2059e0 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1683,6 +1683,19 @@ LogicalResult SingleOp::verify() {
                                   getCopyprivateSyms());
 }
 
+//===----------------------------------------------------------------------===//
+// WorkshareOp
+//===----------------------------------------------------------------------===//
+
+void WorkshareOp::build(OpBuilder &builder, OperationState &state,
+                        const WorkshareOperands &clauses) {
+  WorkshareOp::build(builder, state, clauses.nowait);
+}
+
+LogicalResult WorkshareOp::verify() {
+  return (*this)->getRegion(0).getBlocks().size() == 1 ? success() : failure();
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//

>From f2fd4f278c23ec99dae3ac44e1c05fcb629f707d Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:10:25 +0900
Subject: [PATCH 05/36] Add custom omp loop wrapper

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5199ff50abb95..76f0c472cfdb1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -308,6 +308,17 @@ def WorkshareOp : OpenMP_Op<"workshare", clauses = [
   let hasVerifier = 1;
 }
 
+def WorkshareLoopWrapperOp : OpenMP_Op<"workshare_loop_wrapper", traits = [
+    DeclareOpInterfaceMethods<LoopWrapperInterface>,
+    RecursiveMemoryEffects, SingleBlock
+  ], singleRegion = true> {
+  let summary = "contains loop nests to be parallelized by workshare";
+
+  let builders = [
+    OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Loop Nest
 //===----------------------------------------------------------------------===//

>From 22c66e6db3997e38254d9848661a38627cd7bb19 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:08:58 +0900
Subject: [PATCH 06/36] Add recursive memory effects trait to workshare

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 76f0c472cfdb1..7d1c80333855e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -290,7 +290,9 @@ def SingleOp : OpenMP_Op<"single", traits = [
 // 2.8.3 Workshare Construct
 //===----------------------------------------------------------------------===//
 
-def WorkshareOp : OpenMP_Op<"workshare", clauses = [
+def WorkshareOp : OpenMP_Op<"workshare", traits = [
+    RecursiveMemoryEffects,
+  ], clauses = [
     OpenMP_NowaitClause,
   ], singleRegion = true> {
   let summary = "workshare directive";

>From 5d6094bc98a5a63cf1db6ef9eb2a34337acb86ed Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 17:04:07 +0900
Subject: [PATCH 07/36] Remove stray include

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index d14e5e17afbb0..896ca9581c3fc 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -17,7 +17,6 @@
 
 #include "mlir/IR/BuiltinAttributes.h"
 #include "llvm/ADT/SmallVector.h"
-#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
 

>From e41c776df6dd0af31b6739448323dc704c9716f7 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 21:56:13 +0900
Subject: [PATCH 08/36] Remove omp.workshare verifier

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 2 --
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 4 ----
 2 files changed, 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 7d1c80333855e..863cd81923c87 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -306,8 +306,6 @@ def WorkshareOp : OpenMP_Op<"workshare", traits = [
   let builders = [
     OpBuilder<(ins CArg<"const WorkshareOperands &">:$clauses)>
   ];
-
-  let hasVerifier = 1;
 }
 
 def WorkshareLoopWrapperOp : OpenMP_Op<"workshare_loop_wrapper", traits = [
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 9a189eb2059e0..6c1b77077bdba 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1692,10 +1692,6 @@ void WorkshareOp::build(OpBuilder &builder, OperationState &state,
   WorkshareOp::build(builder, state, clauses.nowait);
 }
 
-LogicalResult WorkshareOp::verify() {
-  return (*this)->getRegion(0).getBlocks().size() == 1 ? success() : failure();
-}
-
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//

>From 94a2e49f1ce0e8d95287041f64909ec23bd189ed Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 6 Aug 2024 13:41:22 +0900
Subject: [PATCH 09/36] Add assembly format for wrapper and add test

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  2 +-
 mlir/test/Dialect/OpenMP/ops.mlir             | 61 +++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 863cd81923c87..0f29e911cb2f2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -313,10 +313,10 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare_loop_wrapper", traits = [
     RecursiveMemoryEffects, SingleBlock
   ], singleRegion = true> {
   let summary = "contains loop nests to be parallelized by workshare";
-
   let builders = [
     OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
   ];
+  let assemblyFormat = "$region attr-dict";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d2924998f41b8..981e3fbb0306b 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2789,3 +2789,64 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
 
   return
 }
+
+// CHECK-LABEL: func @omp_workshare
+func.func @omp_workshare() {
+  // CHECK: omp.workshare {
+  omp.workshare {
+    "test.payload"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_nowait
+func.func @omp_workshare_nowait() {
+  // CHECK: omp.workshare nowait {
+  omp.workshare nowait {
+    "test.payload"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_multiple_blocks
+func.func @omp_workshare_multiple_blocks() {
+  // CHECK: omp.workshare {
+  omp.workshare {
+    cf.br ^bb2
+    ^bb2:
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_loop_wrapper
+func.func @omp_workshare_loop_wrapper(%idx : index) {
+  // CHECK-NEXT: omp.workshare_loop_wrapper
+  omp.workshare_loop_wrapper {
+    // CHECK-NEXT: omp.loop_nest
+    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+      omp.yield
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_loop_wrapper_attrs
+func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
+  // CHECK-NEXT: omp.workshare_loop_wrapper {
+  omp.workshare_loop_wrapper {
+    // CHECK-NEXT: omp.loop_nest
+    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+      omp.yield
+    }
+    omp.terminator
+  // CHECK: } {attr_in_dict}
+  } {attr_in_dict}
+  return
+}

>From 72e19faffb72d3aafed907efd0613416d06cbce0 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 14:42:35 +0900
Subject: [PATCH 10/36] Add verification and descriptions

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 10 +++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 14 +++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 42 +++++++++++++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 34 +++++++++------
 4 files changed, 87 insertions(+), 13 deletions(-)
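
Note: the new `WorkshareLoopWrapperOp::verify()` makes three checks in order,
and each of the three invalid.mlir cases below triggers one of them. The
sketch here models them on a toy operation tree. It is a simplified
illustration, not the MLIR API: `Op`, `isLoopWrapperName` and
`verifyWorkshareLoopWrapper` are hypothetical, terminators are left implicit,
and "is a wrapper" is reduced to "the body holds exactly one loop nest or
wrapper".

```cpp
#include <memory>
#include <string>
#include <vector>

// Toy operation tree; a stand-in for MLIR operations.
struct Op {
  std::string name;
  Op *parent = nullptr;
  std::vector<std::unique_ptr<Op>> body;

  // Appends a child operation and returns a pointer to it.
  Op *add(const std::string &childName) {
    body.push_back(std::make_unique<Op>());
    body.back()->name = childName;
    body.back()->parent = this;
    return body.back().get();
  }
};

bool isLoopWrapperName(const std::string &n) {
  return n == "omp.workshare_loop_wrapper" || n == "omp.simd";
}

// Returns an empty string on success, otherwise the diagnostic text.
std::string verifyWorkshareLoopWrapper(const Op &op) {
  // 1. The body must hold exactly one op: a loop nest or another wrapper.
  if (op.body.size() != 1)
    return "must be a loop wrapper";
  const std::string &child = op.body.front()->name;
  if (child != "omp.loop_nest" && !isLoopWrapperName(child))
    return "must be a loop wrapper";
  // 2. That op must be the loop nest itself, not a further wrapper.
  if (isLoopWrapperName(child))
    return "nested wrappers not supported";
  // 3. Some ancestor must be an omp.workshare.
  for (const Op *p = op.parent; p; p = p->parent)
    if (p->name == "omp.workshare")
      return "";
  return "must be nested in an omp.workshare";
}
```

In this model a wrapper with an empty body fails the first check, a wrapper
around `omp.simd` fails the second, and a well-formed wrapper placed directly
in a function fails the third.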

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 0f29e911cb2f2..74497def8fd1a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -301,6 +301,10 @@ def WorkshareOp : OpenMP_Op<"workshare", traits = [
     block into separate units of work, and causes the threads of the team to
     share the work such that each unit is executed only once by one thread, in
     the context of its implicit task
+
+    This operation is used for the intermediate representation of the workshare
+    block before the work gets divided between the threads. See the flang
+    LowerWorkshare pass for details.
   }] # clausesDescription;
 
   let builders = [
@@ -313,10 +317,16 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare_loop_wrapper", traits = [
     RecursiveMemoryEffects, SingleBlock
   ], singleRegion = true> {
   let summary = "contains loop nests to be parallelized by workshare";
+  let description = [{
+    This operation wraps a loop nest that is marked for dividing into units of
+    work by an encompassing omp.workshare operation.
+  }];
+
   let builders = [
     OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
   ];
   let assemblyFormat = "$region attr-dict";
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6c1b77077bdba..90f9a19ebe32b 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1692,6 +1692,20 @@ void WorkshareOp::build(OpBuilder &builder, OperationState &state,
   WorkshareOp::build(builder, state, clauses.nowait);
 }
 
+//===----------------------------------------------------------------------===//
+// WorkshareLoopWrapperOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkshareLoopWrapperOp::verify() {
+  if (!isWrapper())
+    return emitOpError() << "must be a loop wrapper";
+  if (getNestedWrapper())
+    return emitError() << "nested wrappers not supported";
+  if (!(*this)->getParentOfType<WorkshareOp>())
+    return emitError() << "must be nested in an omp.workshare";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 1d1d93f097758..ee7c448c467cf 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2383,3 +2383,45 @@ func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {
     }) : (i32, i32) -> ()
   return
 }
+
+// -----
+func.func @nested_wrapper(%idx : index) {
+  omp.workshare {
+    // expected-error @below {{nested wrappers not supported}}
+    omp.workshare_loop_wrapper {
+      omp.simd {
+        omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+          omp.yield
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+func.func @not_wrapper() {
+  omp.workshare {
+    // expected-error @below {{must be a loop wrapper}}
+    omp.workshare_loop_wrapper {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+func.func @missing_workshare(%idx : index) {
+  // expected-error @below {{must be nested in an omp.workshare}}
+  omp.workshare_loop_wrapper {
+    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+      omp.yield
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 981e3fbb0306b..24363334b4cb5 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2826,11 +2826,15 @@ func.func @omp_workshare_multiple_blocks() {
 
 // CHECK-LABEL: func @omp_workshare_loop_wrapper
 func.func @omp_workshare_loop_wrapper(%idx : index) {
-  // CHECK-NEXT: omp.workshare_loop_wrapper
-  omp.workshare_loop_wrapper {
-    // CHECK-NEXT: omp.loop_nest
-    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
-      omp.yield
+  // CHECK-NEXT: omp.workshare {
+  omp.workshare {
+    // CHECK-NEXT: omp.workshare_loop_wrapper
+    omp.workshare_loop_wrapper {
+      // CHECK-NEXT: omp.loop_nest
+      omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+        omp.yield
+      }
+      omp.terminator
     }
     omp.terminator
   }
@@ -2839,14 +2843,18 @@ func.func @omp_workshare_loop_wrapper(%idx : index) {
 
 // CHECK-LABEL: func @omp_workshare_loop_wrapper_attrs
 func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
-  // CHECK-NEXT: omp.workshare_loop_wrapper {
-  omp.workshare_loop_wrapper {
-    // CHECK-NEXT: omp.loop_nest
-    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
-      omp.yield
-    }
+  // CHECK-NEXT: omp.workshare {
+  omp.workshare {
+    // CHECK-NEXT: omp.workshare_loop_wrapper {
+    omp.workshare_loop_wrapper {
+      // CHECK-NEXT: omp.loop_nest
+      omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+        omp.yield
+      }
+      omp.terminator
+    // CHECK: } {attr_in_dict}
+    } {attr_in_dict}
     omp.terminator
-  // CHECK: } {attr_in_dict}
-  } {attr_in_dict}
+  }
   return
 }

>From 63d49e4dcd128b470ee77006c594673203dd2df2 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:11:47 +0900
Subject: [PATCH 11/36] [flang][omp] Emit omp.workshare in frontend

---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 30 ++++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 2b1839b5270d4..f7bc565ea8cbc 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1270,6 +1270,15 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter,
       loc, llvm::omp::Directive::OMPD_taskwait);
 }
 
+static void genWorkshareClauses(lower::AbstractConverter &converter,
+                                semantics::SemanticsContext &semaCtx,
+                                lower::StatementContext &stmtCtx,
+                                const List<Clause> &clauses, mlir::Location loc,
+                                mlir::omp::WorkshareOperands &clauseOps) {
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processNowait(clauseOps);
+}
+
 static void genTeamsClauses(lower::AbstractConverter &converter,
                             semantics::SemanticsContext &semaCtx,
                             lower::StatementContext &stmtCtx,
@@ -1890,6 +1899,22 @@ genTaskyieldOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   return converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(loc);
 }
 
+static mlir::omp::WorkshareOp
+genWorkshareOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+           semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+           mlir::Location loc, const ConstructQueue &queue,
+           ConstructQueue::iterator item) {
+  lower::StatementContext stmtCtx;
+  mlir::omp::WorkshareOperands clauseOps;
+  genWorkshareClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
+
+  return genOpWithBody<mlir::omp::WorkshareOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workshare)
+          .setClauses(&item->clauses),
+      queue, item, clauseOps);
+}
+
 static mlir::omp::TeamsOp
 genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
@@ -2249,10 +2274,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                   llvm::omp::getOpenMPDirectiveName(dir) + ")");
   // case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
-    // FIXME: Workshare is not a commonly used OpenMP construct, an
-    // implementation for this feature will come later. For the codes
-    // that use this construct, add a single construct for now.
-    genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genWorkshareOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
 
   // Composite constructs

>From 621b01775171a4718fa405f201b58c3dca005e5a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:02:37 +0900
Subject: [PATCH 12/36] Fix lower test for workshare

---
 flang/test/Lower/OpenMP/workshare.f90 | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flang/test/Lower/OpenMP/workshare.f90 b/flang/test/Lower/OpenMP/workshare.f90
index 1e11677a15e1f..8e771952f5b6d 100644
--- a/flang/test/Lower/OpenMP/workshare.f90
+++ b/flang/test/Lower/OpenMP/workshare.f90
@@ -6,7 +6,7 @@ subroutine sb1(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
   !$omp parallel
-!CHECK: omp.single  {
+!CHECK: omp.workshare {
   !$omp workshare
     arr = 0
   !$omp end workshare
@@ -20,7 +20,7 @@ subroutine sb2(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
   !$omp parallel
-!CHECK: omp.single nowait {
+!CHECK: omp.workshare nowait {
   !$omp workshare
     arr = 0
   !$omp end workshare nowait
@@ -33,7 +33,7 @@ subroutine sb2(arr)
 subroutine sb3(arr)
   integer :: arr(:)
 !CHECK: omp.parallel  {
-!CHECK: omp.single  {
+!CHECK: omp.workshare  {
   !$omp parallel workshare
     arr = 0
   !$omp end parallel workshare

>From 5e470922405b735d63b4aded76450cc52e94e003 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:12:34 +0900
Subject: [PATCH 13/36] [flang] Introduce ws loop nest generation for HLFIR
 lowering

---
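Note on index ordering in the emitWsLoop branch of genLoopNest below: the
bounds are pushed onto the loop nest in reversed extent order, and the final
loop writes block argument `dim` to oneBasedIndices[extents.size() - dim - 1].
The net effect appears to be that oneBasedIndices[i] still refers to extent i,
as in the fir.do_loop branch. The following is a minimal standalone sketch of
that bookkeeping only. InductionVar and mapOneBasedIndices are hypothetical
names used for illustration and are not part of flang or MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a loop-nest block argument: it records which
// extent (array dimension) the induction variable iterates over.
struct InductionVar {
  std::size_t extentIndex = 0;
};

// Models the two loops in the emitWsLoop branch: block arguments are created
// in reversed extent order, then the second loop undoes the reversal.
std::vector<InductionVar> mapOneBasedIndices(std::size_t numExtents) {
  std::vector<InductionVar> blockArgs;
  for (std::size_t i = numExtents; i-- > 0;)
    blockArgs.push_back(InductionVar{i});

  std::vector<InductionVar> oneBasedIndices(numExtents);
  for (std::size_t dim = 0; dim < numExtents; ++dim)
    oneBasedIndices[numExtents - dim - 1] = blockArgs[dim];
  return oneBasedIndices;
}

// Checks that entry i of the result refers to extent i.
inline void checkIdentityMapping(std::size_t numExtents) {
  std::vector<InductionVar> indices = mapOneBasedIndices(numExtents);
  for (std::size_t i = 0; i < numExtents; ++i)
    assert(indices[i].extentIndex == i);
}
```

Calling checkIdentityMapping(3) exercises a rank-3 case.
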
 .../flang/Optimizer/Builder/HLFIRTools.h      | 12 +++--
 flang/lib/Lower/ConvertCall.cpp               |  2 +-
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp |  4 +-
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    | 52 ++++++++++++++-----
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       |  3 +-
 .../LowerHLFIROrderedAssignments.cpp          | 30 +++++------
 .../Transforms/OptimizedBufferization.cpp     |  6 +--
 7 files changed, 69 insertions(+), 40 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 6b41025eea078..14e42c6f358e4 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -357,8 +357,8 @@ hlfir::ElementalOp genElementalOp(
 
 /// Structure to describe a loop nest.
 struct LoopNest {
-  fir::DoLoopOp outerLoop;
-  fir::DoLoopOp innerLoop;
+  mlir::Operation *outerOp;
+  mlir::Block *body;
   llvm::SmallVector<mlir::Value> oneBasedIndices;
 };
 
@@ -366,11 +366,13 @@ struct LoopNest {
 /// \p isUnordered specifies whether the loops in the loop nest
 /// are unordered.
 LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
-                     mlir::ValueRange extents, bool isUnordered = false);
+                     mlir::ValueRange extents, bool isUnordered = false,
+                     bool emitWsLoop = false);
 inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
-                            mlir::Value shape, bool isUnordered = false) {
+                            mlir::Value shape, bool isUnordered = false,
+                            bool emitWsLoop = false) {
   return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
-                     isUnordered);
+                     isUnordered, emitWsLoop);
 }
 
 /// Inline the body of an hlfir.elemental at the current insertion point
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index fd873f55dd844..0689d6e033dd9 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -2128,7 +2128,7 @@ class ElementalCallBuilder {
           hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
       mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
       auto insPt = builder.saveInsertionPoint();
-      builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+      builder.setInsertionPointToStart(loopNest.body);
       callContext.stmtCtx.pushScope();
       for (auto &preparedActual : loweredActuals)
         if (preparedActual)
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c3c1f363033c2..72a90dd0d6f29 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
   // know this won't miss any opportuinties for clever elemental inlining
   hlfir::LoopNest nest = hlfir::genLoopNest(
       loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
-  builder.setInsertionPointToStart(nest.innerLoop.getBody());
+  builder.setInsertionPointToStart(nest.body);
   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
@@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
       builder, loc, redId, refTy, lhsEle, rhsEle);
   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
 
-  builder.setInsertionPointAfter(nest.outerLoop);
+  builder.setInsertionPointAfter(nest.outerOp);
   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
 }
 
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 8d0ae2f195178..cd07cb741eb4b 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include <optional>
 
 // Return explicit extents. If the base is a fir.box, this won't read it to
@@ -855,26 +856,51 @@ mlir::Value hlfir::inlineElementalOp(
 
 hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
-                                   mlir::ValueRange extents, bool isUnordered) {
+                                   mlir::ValueRange extents, bool isUnordered,
+                                   bool emitWsLoop) {
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
-  auto insPt = builder.saveInsertionPoint();
+  mlir::OpBuilder::InsertionGuard guard(builder);
   loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
   // Build loop nest from column to row.
   auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
   mlir::Type indexType = builder.getIndexType();
-  unsigned dim = extents.size() - 1;
-  for (auto extent : llvm::reverse(extents)) {
-    auto ub = builder.createConvert(loc, indexType, extent);
-    loopNest.innerLoop =
-        builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
-    builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
-    // Reverse the indices so they are in column-major order.
-    loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar();
-    if (!loopNest.outerLoop)
-      loopNest.outerLoop = loopNest.innerLoop;
+  if (emitWsLoop) {
+    auto wsloop = builder.create<mlir::omp::WsloopOp>(
+        loc, mlir::ArrayRef<mlir::NamedAttribute>());
+    loopNest.outerOp = wsloop;
+    builder.createBlock(&wsloop.getRegion());
+    mlir::omp::LoopNestOperands lnops;
+    lnops.loopInclusive = builder.getUnitAttr();
+    for (auto extent : llvm::reverse(extents)) {
+      lnops.loopLowerBounds.push_back(one);
+      lnops.loopUpperBounds.push_back(extent);
+      lnops.loopSteps.push_back(one);
+    }
+    auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
+    builder.create<mlir::omp::TerminatorOp>(loc);
+    mlir::Block *block = builder.createBlock(&lnOp.getRegion());
+    for (auto extent : llvm::reverse(extents))
+      block->addArgument(extent.getType(), extent.getLoc());
+    loopNest.body = block;
+    builder.create<mlir::omp::YieldOp>(loc);
+    for (unsigned dim = 0; dim < extents.size(); dim++)
+      loopNest.oneBasedIndices[extents.size() - dim - 1] =
+          lnOp.getRegion().front().getArgument(dim);
+  } else {
+    unsigned dim = extents.size() - 1;
+    for (auto extent : llvm::reverse(extents)) {
+      auto ub = builder.createConvert(loc, indexType, extent);
+      auto doLoop =
+          builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
+      loopNest.body = doLoop.getBody();
+      builder.setInsertionPointToStart(loopNest.body);
+      // Reverse the indices so they are in column-major order.
+      loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
+      if (!loopNest.outerOp)
+        loopNest.outerOp = doLoop;
+    }
   }
-  builder.restoreInsertionPoint(insPt);
   return loopNest;
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index a70a6b388c4b1..b608677c52631 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace hlfir {
@@ -793,7 +794,7 @@ struct ElementalOpConversion
     hlfir::LoopNest loopNest =
         hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
     auto insPt = builder.saveInsertionPoint();
-    builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+    builder.setInsertionPointToStart(loopNest.body);
     auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                           loopNest.oneBasedIndices);
     hlfir::Entity elementValue(yield.getElementValue());
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index 85dd517cb5791..645abf65d10a3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
       // if the LHS is not).
       mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
       elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
-      builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
+      builder.setInsertionPointToStart(elementalLoopNest->body);
       lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
                                       elementalLoopNest->oneBasedIndices);
       rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
@@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
     for (auto &cleanupConversion : argConversionCleanups)
       cleanupConversion();
     if (elementalLoopNest)
-      builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
+      builder.setInsertionPointAfter(elementalLoopNest->outerOp);
   } else {
     // TODO: preserve allocatable assignment aspects for forall once
     // they are conveyed in hlfir.region_assign.
@@ -493,7 +493,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
   generateCleanupIfAny(loweredLhs.elementalCleanup);
   if (loweredLhs.vectorSubscriptLoopNest)
     builder.setInsertionPointAfter(
-        loweredLhs.vectorSubscriptLoopNest->outerLoop);
+        loweredLhs.vectorSubscriptLoopNest->outerOp);
   generateCleanupIfAny(oldRhsYield);
   generateCleanupIfAny(loweredLhs.nonElementalCleanup);
 }
@@ -518,8 +518,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
       hlfir::Entity savedMask{maybeSaved->first};
       mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
       whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
-      constructStack.push_back(whereLoopNest->outerLoop.getOperation());
-      builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+      constructStack.push_back(whereLoopNest->outerOp);
+      builder.setInsertionPointToStart(whereLoopNest->body);
       mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
                                             whereLoopNest->oneBasedIndices);
       generateMaskIfOp(cdt);
@@ -527,7 +527,7 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
         // If this is the same run as the one that saved the value, the clean-up
         // was left-over to be done now.
         auto insertionPoint = builder.saveInsertionPoint();
-        builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+        builder.setInsertionPointAfter(whereLoopNest->outerOp);
         generateCleanupIfAny(maybeSaved->second);
         builder.restoreInsertionPoint(insertionPoint);
       }
@@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
     mask.generateNoneElementalPart(builder, mapper);
     mlir::Value shape = mask.generateShape(builder, mapper);
     whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
-    constructStack.push_back(whereLoopNest->outerLoop.getOperation());
-    builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+    constructStack.push_back(whereLoopNest->outerOp);
+    builder.setInsertionPointToStart(whereLoopNest->body);
     mlir::Value cdt = generateMaskedEntity(mask);
     generateMaskIfOp(cdt);
     return;
@@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
       loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
           loc, builder, loweredLhs.vectorSubscriptShape.value());
       builder.setInsertionPointToStart(
-          loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+          loweredLhs.vectorSubscriptLoopNest->body);
     }
     loweredLhs.lhs = temp->second.fetch(loc, builder);
     return loweredLhs;
@@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
         hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                            !elementalAddrLhs.isOrdered());
     builder.setInsertionPointToStart(
-        loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+        loweredLhs.vectorSubscriptLoopNest->body);
     mapper.map(elementalAddrLhs.getIndices(),
                loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
     for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
@@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
   if (!maskedExpr.noneElementalPartWasGenerated) {
     // Generate none elemental part before the where loops (but inside the
     // current forall loops if any).
-    builder.setInsertionPoint(whereLoopNest->outerLoop);
+    builder.setInsertionPoint(whereLoopNest->outerOp);
     maskedExpr.generateNoneElementalPart(builder, mapper);
   }
   // Generate the none elemental part cleanup after the where loops.
-  builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+  builder.setInsertionPointAfter(whereLoopNest->outerOp);
   maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
   // Generate the value of the current element for the masked expression
   // at the current insertion point (inside the where loops, and any fir.if
@@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
   LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
   fir::factory::TemporaryStorage *temp = nullptr;
   if (loweredLhs.vectorSubscriptLoopNest)
-    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
   if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
     // Vector subscripted entity for which the shape must also be saved on top
     // of the element addresses (e.g. the shape may change in each forall
@@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
     // subscripted LHS.
     auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
     auto insertionPoint = builder.saveInsertionPoint();
-    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
     vectorTmp.pushShape(loc, builder, shape);
     builder.restoreInsertionPoint(insertionPoint);
   } else {
@@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
   if (loweredLhs.vectorSubscriptLoopNest) {
     constructStack.pop_back();
     builder.setInsertionPointAfter(
-        loweredLhs.vectorSubscriptLoopNest->outerLoop);
+        loweredLhs.vectorSubscriptLoopNest->outerOp);
   }
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c5b809514c54c..c4aed6b79df92 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
   // hlfir.elemental region inside the inner loop
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                         loopNest.oneBasedIndices);
   hlfir::Entity elementValue{yield.getElementValue()};
@@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto arrayElement =
       hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
   builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
@@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto rhsArrayElement =
       hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
   rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);

>From d7f1a0c3bb2f0a6ba6f8fff5a8fd84c3061e1984 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:08:34 +0900
Subject: [PATCH 14/36] Emit loop nests in a custom wrapper

---
 flang/include/flang/Optimizer/Builder/HLFIRTools.h |  6 +++---
 flang/lib/Optimizer/Builder/HLFIRTools.cpp         | 11 +++++------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 14e42c6f358e4..6987471957218 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -367,12 +367,12 @@ struct LoopNest {
 /// are unordered.
 LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                      mlir::ValueRange extents, bool isUnordered = false,
-                     bool emitWsLoop = false);
+                     bool emitWorkshareLoop = false);
 inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                             mlir::Value shape, bool isUnordered = false,
-                            bool emitWsLoop = false) {
+                            bool emitWorkshareLoop = false) {
   return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
-                     isUnordered, emitWsLoop);
+                     isUnordered, emitWorkshareLoop);
 }
 
 /// Inline the body of an hlfir.elemental at the current insertion point
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index cd07cb741eb4b..91b1b3d774a01 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -857,7 +857,7 @@ mlir::Value hlfir::inlineElementalOp(
 hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    mlir::ValueRange extents, bool isUnordered,
-                                   bool emitWsLoop) {
+                                   bool emitWorkshareLoop) {
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
   mlir::OpBuilder::InsertionGuard guard(builder);
@@ -865,11 +865,10 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
   // Build loop nest from column to row.
   auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
   mlir::Type indexType = builder.getIndexType();
-  if (emitWsLoop) {
-    auto wsloop = builder.create<mlir::omp::WsloopOp>(
-        loc, mlir::ArrayRef<mlir::NamedAttribute>());
-    loopNest.outerOp = wsloop;
-    builder.createBlock(&wsloop.getRegion());
+  if (emitWorkshareLoop) {
+    auto wslw = builder.create<mlir::omp::WorkshareLoopWrapperOp>(loc);
+    loopNest.outerOp = wslw;
+    builder.createBlock(&wslw.getRegion());
     mlir::omp::LoopNestOperands lnops;
     lnops.loopInclusive = builder.getUnitAttr();
     for (auto extent : llvm::reverse(extents)) {

>From 78e72f9c7ccbb9c9e9f9f989ff30bbcccb6c0fb5 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 22:05:47 +0900
Subject: [PATCH 15/36] Only emit unordered loops as omp loops

---
 flang/lib/Optimizer/Builder/HLFIRTools.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 91b1b3d774a01..333331378841e 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -858,6 +858,7 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    mlir::ValueRange extents, bool isUnordered,
                                    bool emitWorkshareLoop) {
+  emitWorkshareLoop = emitWorkshareLoop && isUnordered;
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
   mlir::OpBuilder::InsertionGuard guard(builder);

>From 082b89259a7a4fe9d189a5c66202bdea22a6589e Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 17:16:22 +0900
Subject: [PATCH 16/36] Fix uninitialized memory bug in genLoopNest

---
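Background, as far as can be read from the earlier patches in this series:
PATCH 13 replaced the fir::DoLoopOp members of LoopNest with raw pointers, and
genLoopNest declares `hlfir::LoopNest loopNest;` and later tests
`if (!loopNest.outerOp)`. Without initializers, that test reads an
indeterminate pointer. The sketch below reproduces the fixed pattern with
hypothetical stand-in types (Operation, Block, and recordLoop are illustrative
and not the MLIR classes).

```cpp
#include <cassert>

// Hypothetical stand-ins for mlir::Operation and mlir::Block.
struct Operation {};
struct Block {};

// Same shape as hlfir::LoopNest after this patch: the default member
// initializers guarantee that a default-constructed LoopNest holds nulls.
struct LoopNest {
  Operation *outerOp = nullptr;
  Block *body = nullptr;
};

// Mirrors the `if (!loopNest.outerOp)` pattern: only the first loop that is
// created (the outermost one) is remembered.
void recordLoop(LoopNest &nest, Operation *loop) {
  if (!nest.outerOp)
    nest.outerOp = loop;
}

// Small usage example: the second call must not overwrite the first loop.
inline void demoRecordLoop() {
  Operation outer, inner;
  LoopNest nest;
  recordLoop(nest, &outer);
  recordLoop(nest, &inner);
  assert(nest.outerOp == &outer);
}
```
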
 flang/include/flang/Optimizer/Builder/HLFIRTools.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 6987471957218..f073f494b3fb2 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -357,8 +357,8 @@ hlfir::ElementalOp genElementalOp(
 
 /// Structure to describe a loop nest.
 struct LoopNest {
-  mlir::Operation *outerOp;
-  mlir::Block *body;
+  mlir::Operation *outerOp = nullptr;
+  mlir::Block *body = nullptr;
   llvm::SmallVector<mlir::Value> oneBasedIndices;
 };
 

>From db7ff551932c2ff40cfd085b063825f81beaac4b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 22:06:55 +0900
Subject: [PATCH 17/36] [flang] Lower omp.workshare to other omp constructs

---
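The doc comment on lowerWorkshare in LowerWorkshare.cpp describes splitting the
workshare body into single-thread regions separated by parallel loops. A rough
standalone model of that partitioning is sketched below. It is an illustration
under simplifying assumptions, not the pass itself: Op, Chunk, and partition
are hypothetical names, operations are reduced to strings, and the handling of
operations that are safe to run on every thread and of values reloaded across
regions is left out.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of an operation inside omp.workshare.
struct Op {
  std::string name;
  bool isParallelLoop; // e.g. omp.wsloop in the doc-comment example
};

// One piece of the lowered sequence: either a run of operations executed by
// a single thread (omp.single) or one loop shared by all threads.
struct Chunk {
  bool isSingle;
  std::vector<std::string> ops;
};

std::vector<Chunk> partition(const std::vector<Op> &body) {
  std::vector<Chunk> chunks;
  for (const Op &op : body) {
    if (op.isParallelLoop) {
      // A parallel loop always forms its own chunk.
      chunks.push_back(Chunk{false, {op.name}});
      continue;
    }
    // Extend the current single region, or open a new one after a loop.
    if (chunks.empty() || !chunks.back().isSingle)
      chunks.push_back(Chunk{true, {}});
    chunks.back().ops.push_back(op.name);
  }
  return chunks;
}

// Usage example following the doc comment: allocmem, wsloop, call, freemem.
inline void demoPartition() {
  std::vector<Chunk> chunks = partition({{"fir.allocmem", false},
                                         {"omp.wsloop", true},
                                         {"fir.call", false},
                                         {"fir.freemem", false}});
  assert(chunks.size() == 3);
}
```
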
 flang/include/flang/Optimizer/OpenMP/Passes.h |   2 +
 .../include/flang/Optimizer/OpenMP/Passes.td  |   4 +
 flang/include/flang/Tools/CLOptions.inc       |   1 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 259 ++++++++++++++++++
 .../Transforms/OpenMP/lower-workshare.mlir    |  81 ++++++
 .../Transforms/OpenMP/lower-workshare5.mlir   |  42 +++
 7 files changed, 390 insertions(+)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare5.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
index 403d79667bf44..11fa4e59f891e 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.h
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -25,6 +25,8 @@ namespace flangomp {
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 
+bool shouldUseWorkshareLowering(mlir::Operation *op);
+
 } // namespace flangomp
 
 #endif // FORTRAN_OPTIMIZER_OPENMP_PASSES_H
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 395178e26a576..1c9d75d8cfaa1 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -37,4 +37,8 @@ def FunctionFiltering : Pass<"omp-function-filtering"> {
   ];
 }
 
+def LowerWorkshare : Pass<"lower-workshare"> {
+  let summary = "Lower workshare construct";
+}
+
 #endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 05b2f31711add..a565effebfa92 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,6 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
+  pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index a8984d256b8f6..2a38f157a851c 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -4,6 +4,7 @@ add_flang_library(FlangOpenMPTransforms
   FunctionFiltering.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkshare.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
new file mode 100644
index 0000000000000..40975552d1fe3
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -0,0 +1,259 @@
+//===- LowerWorkshare.cpp - Lower omp.workshare construct -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Lower omp workshare construct.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+bool shouldUseWorkshareLowering(Operation *op) {
+  auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
+  if (!workshare)
+    return false;
+  return workshare->getParentOfType<omp::ParallelOp>();
+}
+} // namespace flangomp
+
+namespace {
+
+struct SingleRegion {
+  Block::iterator begin, end;
+};
+
+static bool isSupportedByFirAlloca(Type ty) {
+  return !isa<fir::ReferenceType>(ty);
+}
+
+static bool isSafeToParallelize(Operation *op) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface) {
+    return false;
+  }
+  interface.getEffects(effects);
+  if (effects.empty())
+    return true;
+
+  return false;
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops.
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.wsloop {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.wsloop {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which need
+/// to be accessed in all threads in the closest omp.parallel
+///
+/// TODO currently we need to be able to access the encompassing omp.parallel so
+/// that we can allocate temporaries accessible by all threads outside of it.
+/// In case we do not find it, we fall back to converting the omp.workshare to
+/// omp.single.
+/// To better handle this we should probably enable yielding values out of an
+/// omp.single which will be supported by the omp runtime.
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  assert(wsOp.getRegion().getBlocks().size() == 1);
+
+  Location loc = wsOp->getLoc();
+
+  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
+  if (!parallelOp) {
+    wsOp.emitWarning("cannot handle workshare, converting to single");
+    Operation *terminator = wsOp.getRegion().front().getTerminator();
+    wsOp->getBlock()->getOperations().splice(
+        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
+    terminator->erase();
+    return;
+  }
+
+  OpBuilder allocBuilder(parallelOp);
+  OpBuilder rootBuilder(wsOp);
+  IRMapping rootMapping;
+
+  omp::SingleOp singleOp = nullptr;
+
+  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
+                              IRMapping singleMapping) {
+    if (auto reloaded = rootMapping.lookupOrNull(v))
+      return;
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type ty = v.getType();
+    Value alloc, reloaded;
+    if (isSupportedByFirAlloca(ty)) {
+      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+    } else {
+      auto one = allocBuilder.create<LLVM::ConstantOp>(
+          loc, allocBuilder.getI32Type(), 1);
+      alloc =
+          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+      Value toStore = singleBuilder
+                          .create<UnrealizedConversionCastOp>(
+                              loc, llvmPtrTy, singleMapping.lookup(v))
+                          .getResult(0);
+      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
+      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded =
+          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+              .getResult(0);
+    }
+    rootMapping.map(v, reloaded);
+  };
+
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+    IRMapping singleMapping = rootMapping;
+
+    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+      singleBuilder.clone(op, singleMapping);
+      if (isSafeToParallelize(&op)) {
+        rootBuilder.clone(op, rootMapping);
+      } else {
+        // Prepare reloaded values for results of operations that cannot be
+        // safely parallelized and which are used after the region `sr`
+        for (auto res : op.getResults()) {
+          for (auto &use : res.getUses()) {
+            Operation *user = use.getOwner();
+            while (user->getParentOp() != wsOp)
+              user = user->getParentOp();
+            if (!user->isBeforeInBlock(&*sr.end)) {
+              // We need to reload
+              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            }
+          }
+        }
+      }
+    }
+    singleBuilder.create<omp::TerminatorOp>(loc);
+  };
+
+  Block *wsBlock = &wsOp.getRegion().front();
+  assert(wsBlock->getTerminator()->getNumOperands() == 0);
+  Operation *terminator = wsBlock->getTerminator();
+
+  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+
+  auto it = wsBlock->begin();
+  auto getSingleRegion = [&]() {
+    if (&*it == terminator)
+      return false;
+    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+      regions.push_back(pop);
+      it++;
+      return true;
+    }
+    SingleRegion sr;
+    sr.begin = it;
+    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+      it++;
+    sr.end = it;
+    assert(sr.begin != sr.end);
+    regions.push_back(sr);
+    return true;
+  };
+  while (getSingleRegion())
+    ;
+
+  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
+    bool isLast = i + 1 == regions.size();
+    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
+      omp::SingleOperands singleOperands;
+      if (isLast)
+        singleOperands.nowait = rootBuilder.getUnitAttr();
+      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      OpBuilder singleBuilder(singleOp);
+      singleBuilder.createBlock(&singleOp.getRegion());
+      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
+    } else {
+      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
+      if (!isLast)
+        rootBuilder.create<omp::BarrierOp>(loc);
+    }
+  }
+
+  if (!wsOp.getNowait())
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  wsOp->erase();
+
+  return;
+}
+
+class LowerWorksharePass
+    : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
+public:
+  void runOnOperation() override {
+    SmallPtrSet<Operation *, 8> parents;
+    getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
+      Operation *isolatedParent =
+          wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+      parents.insert(isolatedParent);
+
+      lowerWorkshare(wsOp);
+    });
+
+    // Do folding
+    for (Operation *isolatedParent : parents) {
+      RewritePatternSet patterns(&getContext());
+      GreedyRewriteConfig config;
+      // prevent the pattern driver from merging blocks
+      config.enableRegionSimplification =
+          mlir::GreedySimplifyRegionLevel::Disabled;
+      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
+                                              std::move(patterns), config))) {
+        emitError(isolatedParent->getLoc(), "error in lower workshare\n");
+        signalPassFailure();
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
new file mode 100644
index 0000000000000..a8d36443f08bd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -0,0 +1,81 @@
+// RUN: fir-opt --lower-workshare %s | FileCheck %s
+
+module {
+// CHECK-LABEL:   func.func @simple(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           omp.parallel {
+// CHECK:             omp.single {
+// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
+// CHECK:               llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
+// CHECK:               %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
+// CHECK:             %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
+// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                 omp.yield
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.barrier
+// CHECK:             omp.single nowait {
+// CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.barrier
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+  func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
+    omp.parallel {
+      omp.workshare {
+        %c42 = arith.constant 42 : index
+        %c1_i32 = arith.constant 1 : i32
+        %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+        %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+        %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+        %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+        %true = arith.constant true
+        %c1 = arith.constant 1 : index
+        omp.wsloop {
+          omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+            %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+            %8 = fir.load %7 : !fir.ref<i32>
+            %9 = arith.subi %8, %c1_i32 : i32
+            %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+            hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+            omp.yield
+          }
+          omp.terminator
+        }
+        %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+        fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    return
+  }
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare5.mlir b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
new file mode 100644
index 0000000000000..177f8aa8f86c7
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
@@ -0,0 +1,42 @@
+// XFAIL: *
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// TODO we can lower these but we have no guarantee that the parent of
+// omp.workshare supports multi-block regions, thus we fail for now.
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+    ^bb1:
+      %c1 = arith.constant 1 : i32
+      cf.br ^bb3(%c1: i32)
+    ^bb3(%arg1: i32):
+      "test.test2"(%arg1) : (i32) -> ()
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+    ^bb1:
+      %c1 = arith.constant 1 : i32
+      cf.br ^bb3(%c1: i32)
+    ^bb2:
+      "test.test2"(%r) : (i32) -> ()
+      omp.terminator
+    ^bb3(%arg1: i32):
+      %r = "test.test2"(%arg1) : (i32) -> i32
+      cf.br ^bb2
+    }
+    omp.terminator
+  }
+  return
+}
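
Reviewer aid, not part of the patch: the splitting step in lowerWorkshare (the getSingleRegion lambda in LowerWorkshare.cpp above) can be modelled outside MLIR. The sketch below is a standalone toy under stated assumptions: OpKind, Segment and partitionBlock are hypothetical names invented for illustration, and ops are reduced to an enum. It only shows how a block is cut into maximal runs of non-loop ops (each becoming an omp.single) and individual loop ops.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-ins for MLIR operations; these names are illustrative only.
enum class OpKind { Other, Loop, Terminator };

// A half-open range [begin, end) of op indices, plus whether it is a loop.
struct Segment {
  std::size_t begin;
  std::size_t end;
  bool isLoop;
};

// Splits a block (whose last op is the terminator) into maximal runs of
// non-loop ops and individual loop ops, mirroring getSingleRegion.
std::vector<Segment> partitionBlock(const std::vector<OpKind> &ops) {
  assert(!ops.empty() && ops.back() == OpKind::Terminator);
  std::vector<Segment> segments;
  std::size_t it = 0;
  while (ops[it] != OpKind::Terminator) {
    if (ops[it] == OpKind::Loop) {
      // A loop op always forms a segment of its own.
      segments.push_back({it, it + 1, true});
      ++it;
      continue;
    }
    // Otherwise, extend the run until the next loop or the terminator.
    std::size_t begin = it;
    while (ops[it] != OpKind::Terminator && ops[it] != OpKind::Loop)
      ++it;
    assert(begin != it);
    segments.push_back({begin, it, false});
  }
  return segments;
}
```

For the @simple test above, the body is roughly "several ops, one loop, several ops, terminator", so this model yields three segments: a single region, a loop, and a second single region.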

>From 342433e135af020a68939200ca4d89cbc7d499dd Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:41:09 +0900
Subject: [PATCH 18/36] Change to workshare loop wrapper op

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 24 ++++++++++++-------
 .../Transforms/OpenMP/lower-workshare.mlir    |  5 ++--
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 40975552d1fe3..cb342b60de4e8 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/iterator_range.h"
 
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
 #include <variant>
 
 namespace flangomp {
@@ -73,7 +74,7 @@ static bool isSafeToParallelize(Operation *op) {
 ///
 /// omp.workshare {
 ///   %a = fir.allocmem
-///   omp.wsloop {}
+///   omp.workshare_loop_wrapper {}
 ///   fir.call Assign %b %a
 ///   fir.freemem %a
 /// }
@@ -85,7 +86,7 @@ static bool isSafeToParallelize(Operation *op) {
 ///   fir.store %a %tmp
 /// }
 /// %a_reloaded = fir.load %tmp
-/// omp.wsloop {}
+/// omp.workshare_loop_wrapper {}
 /// omp.single {
 ///   fir.call Assign %b %a_reloaded
 ///   fir.freemem %a_reloaded
@@ -180,20 +181,20 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
   assert(wsBlock->getTerminator()->getNumOperands() == 0);
   Operation *terminator = wsBlock->getTerminator();
 
-  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+  SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
 
   auto it = wsBlock->begin();
   auto getSingleRegion = [&]() {
     if (&*it == terminator)
       return false;
-    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+    if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
       regions.push_back(pop);
       it++;
       return true;
     }
     SingleRegion sr;
     sr.begin = it;
-    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+    while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
       it++;
     sr.end = it;
     assert(sr.begin != sr.end);
@@ -214,9 +215,16 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
       singleBuilder.createBlock(&singleOp.getRegion());
       moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
     } else {
-      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
-      if (!isLast)
-        rootBuilder.create<omp::BarrierOp>(loc);
+      omp::WsloopOperands wsloopOperands;
+      if (isLast)
+        wsloopOperands.nowait = rootBuilder.getUnitAttr();
+      auto wsloop =
+          rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+      auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
+      auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+          rootBuilder.clone(*wslw, rootMapping));
+      wsloop.getRegion().takeBody(clonedWslw.getRegion());
+      clonedWslw->erase();
     }
   }
 
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index a8d36443f08bd..cb5791d35916a 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -34,7 +34,6 @@ module {
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             omp.barrier
 // CHECK:             omp.single nowait {
 // CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
 // CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
@@ -56,7 +55,7 @@ module {
         %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
         %true = arith.constant true
         %c1 = arith.constant 1 : index
-        omp.wsloop {
+        "omp.workshare_loop_wrapper"() ({
           omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
             %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
             %8 = fir.load %7 : !fir.ref<i32>
@@ -66,7 +65,7 @@ module {
             omp.yield
           }
           omp.terminator
-        }
+        }) : () -> ()
         %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
         %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
         %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
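
Reviewer aid, not part of the patch: with this change the explicit omp.barrier after a loop is dropped, and instead the last construct in the body (single or wsloop) gets nowait, followed by one trailing barrier unless the workshare itself is nowait. The toy sketch below models just that placement rule; emitConstructs and its parameters are hypothetical names used for illustration, and the strings are labels rather than real IR.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Returns labels for the constructs emitted for a workshare body.
// isLoop[i] tells whether segment i is a loop (true) or a run of other ops
// (false); outerNowait models the nowait clause on omp.workshare itself.
std::vector<std::string> emitConstructs(const std::vector<bool> &isLoop,
                                        bool outerNowait) {
  std::vector<std::string> out;
  for (std::size_t i = 0; i < isLoop.size(); ++i) {
    bool isLast = i + 1 == isLoop.size();
    std::string name = isLoop[i] ? "omp.wsloop" : "omp.single";
    // Only the last construct may skip its implicit barrier.
    if (isLast)
      name += " nowait";
    out.push_back(name);
  }
  // One barrier closes the whole workshare unless it was marked nowait.
  if (!outerNowait)
    out.push_back("omp.barrier");
  return out;
}
```

For a "single, loop, single" body without nowait this gives omp.single, omp.wsloop, omp.single nowait, omp.barrier, which matches the order of the CHECK lines in lower-workshare.mlir after this patch.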

>From 4a0408128d3b3543394871e7b47c6af56e4e2375 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:47:27 +0900
Subject: [PATCH 19/36] Move single op declaration

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index cb342b60de4e8..2322d2acbc013 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -120,8 +120,6 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
   OpBuilder rootBuilder(wsOp);
   IRMapping rootMapping;
 
-  omp::SingleOp singleOp = nullptr;
-
   auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
                               IRMapping singleMapping) {
     if (auto reloaded = rootMapping.lookupOrNull(v))
@@ -210,7 +208,8 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
       omp::SingleOperands singleOperands;
       if (isLast)
         singleOperands.nowait = rootBuilder.getUnitAttr();
-      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      omp::SingleOp singleOp =
+          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
       OpBuilder singleBuilder(singleOp);
       singleBuilder.createBlock(&singleOp.getRegion());
       moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);

>From 97768eab5fe64e9c1ed2c22721ca5ffeff2a8b67 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 17:13:58 +0900
Subject: [PATCH 20/36] Schedule pass properly

---
 flang/include/flang/Tools/CLOptions.inc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index a565effebfa92..b256dba3c9294 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  pm.addPass(flangomp::createLowerWorkshare());
+  addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed

>From 89271ff74dbaf707705a553e4e18194de50093db Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 00:30:40 +0900
Subject: [PATCH 21/36] Correctly handle nested nested loop nests to be
 parallelized by workshare

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 256 ++++++++++--------
 1 file changed, 138 insertions(+), 118 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 2322d2acbc013..8e79d1401c01c 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -19,9 +19,14 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/iterator_range.h"
 
+#include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <variant>
 
 namespace flangomp {
@@ -52,90 +57,40 @@ static bool isSupportedByFirAlloca(Type ty) {
   return !isa<fir::ReferenceType>(ty);
 }
 
-static bool isSafeToParallelize(Operation *op) {
-  if (isa<fir::DeclareOp>(op))
-    return true;
-
-  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
-  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
-  if (!interface) {
-    return false;
-  }
-  interface.getEffects(effects);
-  if (effects.empty())
-    return true;
-
-  return false;
+static bool mustParallelizeOp(Operation *op) {
+  return op
+      ->walk(
+          [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
+      .wasInterrupted();
 }
 
-/// Lowers workshare to a sequence of single-thread regions and parallel loops
-///
-/// For example:
-///
-/// omp.workshare {
-///   %a = fir.allocmem
-///   omp.workshare_loop_wrapper {}
-///   fir.call Assign %b %a
-///   fir.freemem %a
-/// }
-///
-/// becomes
-///
-/// omp.single {
-///   %a = fir.allocmem
-///   fir.store %a %tmp
-/// }
-/// %a_reloaded = fir.load %tmp
-/// omp.workshare_loop_wrapper {}
-/// omp.single {
-///   fir.call Assign %b %a_reloaded
-///   fir.freemem %a_reloaded
-/// }
-///
-/// Note that we allocate temporary memory for values in omp.single's which need
-/// to be accessed in all threads in the closest omp.parallel
-///
-/// TODO currently we need to be able to access the encompassing omp.parallel so
-/// that we can allocate temporaries accessible by all threads outside of it.
-/// In case we do not find it, we fall back to converting the omp.workshare to
-/// omp.single.
-/// To better handle this we should probably enable yielding values out of an
-/// omp.single which will be supported by the omp runtime.
-void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
-  assert(wsOp.getRegion().getBlocks().size() == 1);
-
-  Location loc = wsOp->getLoc();
+static bool isSafeToParallelize(Operation *op) {
+  return isa<fir::DeclareOp>(op) || isPure(op);
+}
 
-  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
-  if (!parallelOp) {
-    wsOp.emitWarning("cannot handle workshare, converting to single");
-    Operation *terminator = wsOp.getRegion().front().getTerminator();
-    wsOp->getBlock()->getOperations().splice(
-        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
-    terminator->erase();
-    return;
-  }
-
-  OpBuilder allocBuilder(parallelOp);
-  OpBuilder rootBuilder(wsOp);
-  IRMapping rootMapping;
+static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
+                              IRMapping &rootMapping, Location loc) {
+  Operation *parentOp = sourceRegion.getParentOp();
+  OpBuilder rootBuilder(sourceRegion.getContext());
 
+  // TODO need to copyprivate the alloca's
   auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
                               IRMapping singleMapping) {
+    OpBuilder allocaBuilder(&targetRegion.front().front());
     if (auto reloaded = rootMapping.lookupOrNull(v))
       return;
-    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
     Value alloc, reloaded;
     if (isSupportedByFirAlloca(ty)) {
-      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
       singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
       reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
     } else {
-      auto one = allocBuilder.create<LLVM::ConstantOp>(
-          loc, allocBuilder.getI32Type(), 1);
+      auto one = allocaBuilder.create<LLVM::ConstantOp>(
+          loc, allocaBuilder.getI32Type(), 1);
       alloc =
-          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+          allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
       Value toStore = singleBuilder
                           .create<UnrealizedConversionCastOp>(
                               loc, llvmPtrTy, singleMapping.lookup(v))
@@ -162,9 +117,10 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
         for (auto res : op.getResults()) {
           for (auto &use : res.getUses()) {
             Operation *user = use.getOwner();
-            while (user->getParentOp() != wsOp)
+            while (user->getParentOp() != parentOp)
               user = user->getParentOp();
-            if (!user->isBeforeInBlock(&*sr.end)) {
+            if (!(user->isBeforeInBlock(&*sr.end) &&
+                  sr.begin->isBeforeInBlock(user))) {
               // We need to reload
               mapReloadedValue(use.get(), singleBuilder, singleMapping);
             }
@@ -175,61 +131,125 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
     singleBuilder.create<omp::TerminatorOp>(loc);
   };
 
-  Block *wsBlock = &wsOp.getRegion().front();
-  assert(wsBlock->getTerminator()->getNumOperands() == 0);
-  Operation *terminator = wsBlock->getTerminator();
+  // TODO Need to handle these (clone them) in dominator tree order
+  for (Block &block : sourceRegion) {
+    rootBuilder.createBlock(
+        &targetRegion, {}, block.getArgumentTypes(),
+        llvm::map_to_vector(block.getArguments(),
+                            [](BlockArgument arg) { return arg.getLoc(); }));
+    Operation *terminator = block.getTerminator();
 
-  SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
+    SmallVector<std::variant<SingleRegion, Operation *>> regions;
 
-  auto it = wsBlock->begin();
-  auto getSingleRegion = [&]() {
-    if (&*it == terminator)
-      return false;
-    if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
-      regions.push_back(pop);
-      it++;
+    auto it = block.begin();
+    auto getOneRegion = [&]() {
+      if (&*it == terminator)
+        return false;
+      if (mustParallelizeOp(&*it)) {
+        regions.push_back(&*it);
+        it++;
+        return true;
+      }
+      SingleRegion sr;
+      sr.begin = it;
+      while (&*it != terminator && !mustParallelizeOp(&*it))
+        it++;
+      sr.end = it;
+      assert(sr.begin != sr.end);
+      regions.push_back(sr);
       return true;
+    };
+    while (getOneRegion())
+      ;
+
+    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
+      bool isLast = i + 1 == regions.size();
+      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        omp::SingleOperands singleOperands;
+        if (isLast)
+          singleOperands.nowait = rootBuilder.getUnitAttr();
+        omp::SingleOp singleOp =
+            rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+        OpBuilder singleBuilder(singleOp);
+        singleBuilder.createBlock(&singleOp.getRegion());
+        moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
+      } else {
+        auto op = std::get<Operation *>(opOrSingle);
+        if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
+          omp::WsloopOperands wsloopOperands;
+          if (isLast)
+            wsloopOperands.nowait = rootBuilder.getUnitAttr();
+          auto wsloop =
+              rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+          auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+              rootBuilder.clone(*wslw, rootMapping));
+          wsloop.getRegion().takeBody(clonedWslw.getRegion());
+          clonedWslw->erase();
+        } else {
+          assert(mustParallelizeOp(op));
+          Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
+          for (auto [region, clonedRegion] :
+               llvm::zip(op->getRegions(), cloned->getRegions()))
+            parallelizeRegion(region, clonedRegion, rootMapping, loc);
+        }
+      }
     }
-    SingleRegion sr;
-    sr.begin = it;
-    while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
-      it++;
-    sr.end = it;
-    assert(sr.begin != sr.end);
-    regions.push_back(sr);
-    return true;
-  };
-  while (getSingleRegion())
-    ;
-
-  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
-    bool isLast = i + 1 == regions.size();
-    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
-      omp::SingleOperands singleOperands;
-      if (isLast)
-        singleOperands.nowait = rootBuilder.getUnitAttr();
-      omp::SingleOp singleOp =
-          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
-      OpBuilder singleBuilder(singleOp);
-      singleBuilder.createBlock(&singleOp.getRegion());
-      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
-    } else {
-      omp::WsloopOperands wsloopOperands;
-      if (isLast)
-        wsloopOperands.nowait = rootBuilder.getUnitAttr();
-      auto wsloop =
-          rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
-      auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
-      auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
-          rootBuilder.clone(*wslw, rootMapping));
-      wsloop.getRegion().takeBody(clonedWslw.getRegion());
-      clonedWslw->erase();
-    }
+
+    rootBuilder.clone(*block.getTerminator(), rootMapping);
   }
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.workshare_loop_wrapper {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.workshare_loop_wrapper {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which need
+/// to be accessed in all threads in the closest omp.parallel
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  Location loc = wsOp->getLoc();
+  IRMapping rootMapping;
+
+  OpBuilder rootBuilder(wsOp);
+
+  // TODO We need something like an scf.execute_region here, but that is not
+  // registered, so we use fir.if for now. However, fir.if does not appear to
+  // support multiple blocks, so this does not work for the multi-block case.
+  auto ifOp = rootBuilder.create<fir::IfOp>(
+      loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
+  ifOp.getThenRegion().front().erase();
+
+  parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
+
+  Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
+  assert(isa<omp::TerminatorOp>(terminatorOp));
+  OpBuilder termBuilder(terminatorOp);
 
   if (!wsOp.getNowait())
-    rootBuilder.create<omp::BarrierOp>(loc);
+    termBuilder.create<omp::BarrierOp>(loc);
+
+  termBuilder.create<fir::ResultOp>(loc, ValueRange());
 
+  terminatorOp->erase();
   wsOp->erase();
 
   return;

>From 58108d0d1fedab588448edf7e097698cc2f99a0e Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 00:33:57 +0900
Subject: [PATCH 22/36] Leave comments for shouldUseWorkshareLowering

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)
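
Note on the TODO in this patch: the nesting concern can be modelled outside of
MLIR with a toy parent chain. The sketch below is only an illustration under
assumptions; ToyKind, ToyNode and boundToWorkshare are hypothetical names and
not part of flang or MLIR. It shows one possible rule: an op binds to a
workshare only if that workshare is reached before any enclosing parallel.

```cpp
#include <cassert>

// Hypothetical stand-ins for the kinds of enclosing constructs.
enum class ToyKind { Func, Parallel, Workshare, Other };

// Hypothetical stand-in for an operation with a link to its parent op.
struct ToyNode {
  ToyKind kind;
  const ToyNode *parent; // nullptr at the top level
};

// Returns true only if an enclosing workshare is reached before any
// enclosing parallel while walking up the parent chain.
static bool boundToWorkshare(const ToyNode &op) {
  for (const ToyNode *p = op.parent; p != nullptr; p = p->parent) {
    if (p->kind == ToyKind::Workshare)
      return true;
    // A nested parallel is a unit of work executed by a new thread team,
    // so anything inside it does not bind to the outer workshare.
    if (p->kind == ToyKind::Parallel)
      return false;
  }
  return false;
}
```

With this rule, an op directly inside the workshare is accepted, while an op
inside a parallel nested in the workshare is rejected. The single and critical
cases mentioned in the TODO are not modelled here.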

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 8e79d1401c01c..40dae0fd848ef 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -40,10 +40,23 @@ using namespace mlir;
 
 namespace flangomp {
 bool shouldUseWorkshareLowering(Operation *op) {
-  auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
-  if (!workshare)
-    return false;
-  return workshare->getParentOfType<omp::ParallelOp>();
+  // TODO this is insufficient, as we could have
+  // omp.parallel {
+  //   omp.workshare {
+  //     omp.parallel {
+  //       hlfir.elemental {}
+  //
+  // Then this hlfir.elemental shall _not_ use the lowering for workshare
+  //
+  // Standard says:
+  //   For a parallel construct, the construct is a unit of work with respect to
+  //   the workshare construct. The statements contained in the parallel
+  //   construct are executed by a new thread team.
+  //
+  // TODO similarly for single, critical, etc. Need to think through the
+  // patterns and implement this function.
+  //
+  return op->getParentOfType<omp::WorkshareOp>();
 }
 } // namespace flangomp
 

>From 26a612d2d2558595995b64f921dccdf6de5494f6 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 13:14:38 +0900
Subject: [PATCH 23/36] Use copyprivate to scatter val from omp.single

TODO still need to implement copy function
TODO transitive check for usage outside of omp.single not implemented yet
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |   3 +-
 flang/include/flang/Tools/CLOptions.inc       |   2 +-
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 138 ++++++++++++++----
 3 files changed, 109 insertions(+), 34 deletions(-)

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 1c9d75d8cfaa1..041240cad12eb 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -37,7 +37,8 @@ def FunctionFiltering : Pass<"omp-function-filtering"> {
   ];
 }
 
-def LowerWorkshare : Pass<"lower-workshare"> {
+// Needs to be scheduled on Module as we create functions in it
+def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index b256dba3c9294..a565effebfa92 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
+  pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 40dae0fd848ef..950737fccada7 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -8,25 +8,27 @@
 // Lower omp workshare construct.
 //===----------------------------------------------------------------------===//
 
-#include "flang/Optimizer/Dialect/FIROps.h"
-#include "flang/Optimizer/Dialect/FIRType.h"
-#include "flang/Optimizer/OpenMP/Passes.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVectorExtras.h"
-#include "llvm/ADT/iterator_range.h"
-
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
-#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
 #include <mlir/IR/Visitors.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
+
 #include <variant>
 
 namespace flangomp {
@@ -71,6 +73,8 @@ static bool isSupportedByFirAlloca(Type ty) {
 }
 
 static bool mustParallelizeOp(Operation *op) {
+  // TODO as in shouldUseWorkshareLowering we be careful not to pick up
+  // workshare_loop_wrapper in nested omp.parallel ops
   return op
       ->walk(
           [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
@@ -78,7 +82,33 @@ static bool mustParallelizeOp(Operation *op) {
 }
 
 static bool isSafeToParallelize(Operation *op) {
-  return isa<fir::DeclareOp>(op) || isPure(op);
+  return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
+         isMemoryEffectFree(op);
+}
+
+static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
+                                         fir::FirOpBuilder builder) {
+  mlir::ModuleOp module = builder.getModule();
+  mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
+
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+
+  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
+    return decl;
+  // create function
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
+  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
+  mlir::func::FuncOp funcOp =
+      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+                      {loc, loc});
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+  builder.create<mlir::func::ReturnOp>(loc);
+  return funcOp;
 }
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
@@ -86,19 +116,23 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
   Operation *parentOp = sourceRegion.getParentOp();
   OpBuilder rootBuilder(sourceRegion.getContext());
 
+  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
+  OpBuilder copyFuncBuilder(m.getBodyRegion());
+  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
+
   // TODO need to copyprivate the alloca's
-  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
-                              IRMapping singleMapping) {
-    OpBuilder allocaBuilder(&targetRegion.front().front());
+  auto mapReloadedValue =
+      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
+          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
     if (auto reloaded = rootMapping.lookupOrNull(v))
-      return;
+      return nullptr;
     Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
     Value alloc, reloaded;
     if (isSupportedByFirAlloca(ty)) {
       alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
       singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
-      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+      reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
     } else {
       auto one = allocaBuilder.create<LLVM::ConstantOp>(
           loc, allocaBuilder.getI32Type(), 1);
@@ -109,21 +143,25 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               loc, llvmPtrTy, singleMapping.lookup(v))
                           .getResult(0);
       singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
-      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
       reloaded =
-          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+          parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
               .getResult(0);
     }
     rootMapping.map(v, reloaded);
+    return alloc;
   };
 
-  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
+                          OpBuilder singleBuilder,
+                          OpBuilder parallelBuilder) -> SmallVector<Value> {
     IRMapping singleMapping = rootMapping;
+    SmallVector<Value> copyPrivate;
 
     for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
       singleBuilder.clone(op, singleMapping);
       if (isSafeToParallelize(&op)) {
-        rootBuilder.clone(op, rootMapping);
+        parallelBuilder.clone(op, rootMapping);
       } else {
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
@@ -132,16 +170,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
             Operation *user = use.getOwner();
             while (user->getParentOp() != parentOp)
               user = user->getParentOp();
-            if (!(user->isBeforeInBlock(&*sr.end) &&
-                  sr.begin->isBeforeInBlock(user))) {
-              // We need to reload
-              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            // TODO we need to look at transitively used vals
+            if (true || !(user->isBeforeInBlock(&*sr.end) &&
+                          sr.begin->isBeforeInBlock(user))) {
+              auto alloc =
+                  mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
+                                   parallelBuilder, singleMapping);
+              if (alloc)
+                copyPrivate.push_back(alloc);
             }
           }
         }
       }
     }
     singleBuilder.create<omp::TerminatorOp>(loc);
+    return copyPrivate;
   };
 
   // TODO Need to handle these (clone them) in dominator tree order
@@ -178,14 +221,45 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
       bool isLast = i + 1 == regions.size();
       if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        OpBuilder singleBuilder(sourceRegion.getContext());
+        Block *singleBlock = new Block();
+        singleBuilder.setInsertionPointToStart(singleBlock);
+
+        OpBuilder allocaBuilder(sourceRegion.getContext());
+        Block *allocaBlock = new Block();
+        allocaBuilder.setInsertionPointToStart(allocaBlock);
+
+        OpBuilder parallelBuilder(sourceRegion.getContext());
+        Block *parallelBlock = new Block();
+        parallelBuilder.setInsertionPointToStart(parallelBlock);
+
         omp::SingleOperands singleOperands;
         if (isLast)
           singleOperands.nowait = rootBuilder.getUnitAttr();
+        auto insPtAtSingle = rootBuilder.saveInsertionPoint();
+        singleOperands.copyprivateVars =
+            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
+                         singleBuilder, parallelBuilder);
+        for (auto var : singleOperands.copyprivateVars) {
+          Type ty;
+          if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
+            ty = firAlloca.getAllocatedType();
+          } else {
+            llvm_unreachable("unexpected");
+          }
+          mlir::func::FuncOp funcOp =
+              createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
+          singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
+        }
         omp::SingleOp singleOp =
             rootBuilder.create<omp::SingleOp>(loc, singleOperands);
-        OpBuilder singleBuilder(singleOp);
-        singleBuilder.createBlock(&singleOp.getRegion());
-        moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
+        singleOp.getRegion().push_back(singleBlock);
+        rootBuilder.getInsertionBlock()->getOperations().splice(
+            rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
+        targetRegion.front().getOperations().splice(
+            singleOp->getIterator(), allocaBlock->getOperations());
+        delete allocaBlock;
+        delete parallelBlock;
       } else {
         auto op = std::get<Operation *>(opOrSingle);
         if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {

>From 60043f94c421675584511a8b0d4fc737e7293d3e Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 13:39:49 +0900
Subject: [PATCH 24/36] Transitively check for users outside of single op

TODO need to implement copy func
TODO need to hoist allocas outside of single regions
---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 51 ++++++++++++++-----
 1 file changed, 37 insertions(+), 14 deletions(-)
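
Note on the transitive check in this patch: the idea can be tried on a toy
use-def graph without MLIR. The sketch below is a simplified model under
assumptions; ToyOp, ToyRegion and isOutside are hypothetical, and block
positions stand in for isBeforeInBlock. A value needs a reload if some user
lies outside the single region, or if a safe-to-parallelize user inside the
region has a result that is itself needed outside.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for an operation; every name here is hypothetical.
struct ToyOp {
  int position;               // index of the op within its block
  bool safeToParallelize;     // plays the role of isSafeToParallelize(op)
  std::vector<ToyOp *> users; // ops that consume this op's result
};

// Half-open range [begin, end) of block positions forming a single region.
struct ToyRegion {
  int begin;
  int end;
};

static bool isOutside(const ToyOp &op, ToyRegion sr) {
  return op.position < sr.begin || op.position >= sr.end;
}

// Returns true if the result of `def` is needed outside the single region,
// either directly or through a chain of safe-to-parallelize users.
static bool isTransitivelyUsedOutside(const ToyOp &def, ToyRegion sr) {
  assert(sr.begin <= sr.end && "malformed region");
  for (const ToyOp *user : def.users) {
    if (isOutside(*user, sr))
      return true;
    // An unsafe user is handled on its own, so the chain stops here.
    if (!user->safeToParallelize)
      continue;
    if (isTransitivelyUsedOutside(*user, sr))
      return true;
  }
  return false;
}
```

For example, an allocation whose only user is a declare inside the region
still needs a reload when the declare feeds a loop after the region, because
the declare is cloned on every thread and needs the allocation as its operand.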

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 950737fccada7..2e88d852ff2cb 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -111,6 +111,38 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
   return funcOp;
 }
 
+static bool isUserOutsideSR(Operation *user, Operation *parentOp,
+                            SingleRegion sr) {
+  while (user->getParentOp() != parentOp)
+    user = user->getParentOp();
+  return sr.begin->getBlock() != user->getBlock() ||
+         !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
+}
+
+static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
+  Block *srBlock = sr.begin->getBlock();
+  Operation *parentOp = srBlock->getParentOp();
+
+  for (auto &use : v.getUses()) {
+    Operation *user = use.getOwner();
+    if (isUserOutsideSR(user, parentOp, sr))
+      return true;
+
+    // Results of nested users cannot be used outside of the SR
+    if (user->getBlock() != srBlock)
+      continue;
+
+    // A non-safe to parallelize operation will be handled separately
+    if (!isSafeToParallelize(user))
+      continue;
+
+    for (auto res : user->getResults())
+      if (isTransitivelyUsedOutside(res, sr))
+        return true;
+  }
+  return false;
+}
+
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
   Operation *parentOp = sourceRegion.getParentOp();
@@ -166,19 +198,11 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
         for (auto res : op.getResults()) {
-          for (auto &use : res.getUses()) {
-            Operation *user = use.getOwner();
-            while (user->getParentOp() != parentOp)
-              user = user->getParentOp();
-            // TODO we need to look at transitively used vals
-            if (true || !(user->isBeforeInBlock(&*sr.end) &&
-                          sr.begin->isBeforeInBlock(user))) {
-              auto alloc =
-                  mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
-                                   parallelBuilder, singleMapping);
-              if (alloc)
-                copyPrivate.push_back(alloc);
-            }
+          if (isTransitivelyUsedOutside(res, sr)) {
+            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
+                                          parallelBuilder, singleMapping);
+            if (alloc)
+              copyPrivate.push_back(alloc);
           }
         }
       }
@@ -236,7 +260,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         omp::SingleOperands singleOperands;
         if (isLast)
           singleOperands.nowait = rootBuilder.getUnitAttr();
-        auto insPtAtSingle = rootBuilder.saveInsertionPoint();
         singleOperands.copyprivateVars =
             moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                          singleBuilder, parallelBuilder);

>From 7f3d4cb2d9cf373bc4acc856420ba58ddd339e7f Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 14:30:43 +0900
Subject: [PATCH 25/36] Add tests

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  25 +-
 .../Transforms/OpenMP/lower-workshare.mlir    | 230 +++++++++++++-----
 2 files changed, 188 insertions(+), 67 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 2e88d852ff2cb..30af2556cf4ca 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -18,6 +18,7 @@
 #include <llvm/ADT/iterator_range.h>
 #include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
 #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
 #include <mlir/IR/BuiltinOps.h>
@@ -75,6 +76,14 @@ static bool isSupportedByFirAlloca(Type ty) {
 static bool mustParallelizeOp(Operation *op) {
   // TODO as in shouldUseWorkshareLowering we be careful not to pick up
   // workshare_loop_wrapper in nested omp.parallel ops
+  //
+  // e.g.
+  //
+  // omp.parallel {
+  //   omp.workshare {
+  //     omp.parallel {
+  //       omp.workshare {
+  //         omp.workshare_loop_wrapper {}
   return op
       ->walk(
           [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
@@ -89,10 +98,14 @@ static bool isSafeToParallelize(Operation *op) {
 static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                          fir::FirOpBuilder builder) {
   mlir::ModuleOp module = builder.getModule();
-  mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
-
-  std::string copyFuncName =
-      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  std::string copyFuncName;
+  if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
+    mlir::Type eleTy = rt.getEleTy();
+    copyFuncName =
+        fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  } else {
+    copyFuncName = "_workshare_copy_llvm_ptr";
+  }
 
   if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
     return decl;
@@ -145,9 +158,7 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
-  Operation *parentOp = sourceRegion.getParentOp();
   OpBuilder rootBuilder(sourceRegion.getContext());
-
   ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
   OpBuilder copyFuncBuilder(m.getBodyRegion());
   fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
@@ -268,7 +279,7 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
             ty = firAlloca.getAllocatedType();
           } else {
-            llvm_unreachable("unexpected");
+            ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
           }
           mlir::func::FuncOp funcOp =
               createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index cb5791d35916a..19123e71cacf6 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -1,80 +1,190 @@
-// RUN: fir-opt --lower-workshare %s | FileCheck %s
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
-module {
-// CHECK-LABEL:   func.func @simple(
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+    omp.workshare {
+      %c42 = arith.constant 42 : index
+      %c1_i32 = arith.constant 1 : i32
+      %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+      %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+      %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+      %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+      %true = arith.constant true
+      %c1 = arith.constant 1 : index
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+          %8 = fir.load %7 : !fir.ref<i32>
+          %9 = arith.subi %8, %c1_i32 : i32
+          %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+          hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+      %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+      fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+
+// -----
+
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.workshare {
+    %c1_i32 = arith.constant 1 : i32
+    %alloc = fir.alloca i32
+    fir.store %c1_i32 to %alloc : !fir.ref<i32>
+    %c42 = arith.constant 42 : index
+    %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+    %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+    %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+    %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+    %true = arith.constant true
+    %c1 = arith.constant 1 : index
+    "omp.workshare_loop_wrapper"() ({
+      omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+        %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        %8 = fir.load %7 : !fir.ref<i32>
+        %ld = fir.load %alloc : !fir.ref<i32>
+        %n8 = arith.subi %8, %ld : i32
+        %9 = arith.subi %n8, %c1_i32 : i32
+        %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
+    }) : () -> ()
+    %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    "test.test1"(%alloc) : (!fir.ref<i32>) -> ()
+    hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+    fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+    omp.terminator
+  }
+  return
+}
+
+
+// CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK:           %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
-// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           %[[VAL_4:.*]] = arith.constant true
 // CHECK:           omp.parallel {
-// CHECK:             omp.single {
-// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
-// CHECK:               llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
-// CHECK:               %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             fir.if %[[VAL_4]] {
+// CHECK:               %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:                 %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:                 %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:                 %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:                 fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:                 %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               omp.wsloop {
+// CHECK:                 omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                   %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK:                   %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
+// CHECK:                   %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.single nowait {
+// CHECK:                 hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:                 fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.barrier
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_workshare_copy_llvm_ptr(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME:                                                %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @wsfunc(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_5:.*]] = arith.constant true
+// CHECK:           fir.if %[[VAL_5]] {
+// CHECK:             %[[VAL_6:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK:             %[[VAL_7:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_6]] -> @_workshare_copy_llvm_ptr : !llvm.ptr, %[[VAL_7]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               %[[VAL_8:.*]] = fir.alloca i32
+// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !fir.ref<i32> to !llvm.ptr
+// CHECK:               llvm.store %[[VAL_9]], %[[VAL_6]] : !llvm.ptr, !llvm.ptr
+// CHECK:               fir.store %[[VAL_3]] to %[[VAL_8]] : !fir.ref<i32>
+// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_12:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_12]] to %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
-// CHECK:             %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
-// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_14:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> !llvm.ptr
+// CHECK:             %[[VAL_15:.*]] = builtin.unrealized_conversion_cast %[[VAL_14]] : !llvm.ptr to !fir.ref<i32>
+// CHECK:             %[[VAL_16:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_16]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_18]](%[[VAL_16]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
-// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.loop_nest (%[[VAL_20:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_17]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_24:.*]] = arith.subi %[[VAL_22]], %[[VAL_23]] : i32
+// CHECK:                 %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_3]] : i32
+// CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_19]]#0 (%[[VAL_20]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_25]] to %[[VAL_26]] temporary_lhs : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.single nowait {
-// CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK:               "test.test1"(%[[VAL_15]]) : (!fir.ref<i32>) -> ()
+// CHECK:               hlfir.assign %[[VAL_19]]#0 to %[[VAL_17]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_19]]#0 : !fir.heap<!fir.array<42xi32>>
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.barrier
-// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
-  func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
-    omp.parallel {
-      omp.workshare {
-        %c42 = arith.constant 42 : index
-        %c1_i32 = arith.constant 1 : i32
-        %0 = fir.shape %c42 : (index) -> !fir.shape<1>
-        %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-        %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-        %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-        %true = arith.constant true
-        %c1 = arith.constant 1 : index
-        "omp.workshare_loop_wrapper"() ({
-          omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
-            %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-            %8 = fir.load %7 : !fir.ref<i32>
-            %9 = arith.subi %8, %c1_i32 : i32
-            %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-            hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
-            omp.yield
-          }
-          omp.terminator
-        }) : () -> ()
-        %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-        fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
-        omp.terminator
-      }
-      omp.terminator
-    }
-    return
-  }
-}
+

>From 2f16bb4239df024217f6808f4f4fd7acc2fb194e Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 15:50:58 +0900
Subject: [PATCH 26/36] Hoist allocas

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 10 ++-
 .../Transforms/OpenMP/lower-workshare.mlir    | 69 +++++++++----------
 2 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 30af2556cf4ca..d0cd235d3eb07 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -163,7 +163,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
   OpBuilder copyFuncBuilder(m.getBodyRegion());
   fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
 
-  // TODO need to copyprivate the alloca's
   auto mapReloadedValue =
       [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
           OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
@@ -202,10 +201,17 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     SmallVector<Value> copyPrivate;
 
     for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
-      singleBuilder.clone(op, singleMapping);
       if (isSafeToParallelize(&op)) {
+        singleBuilder.clone(op, singleMapping);
         parallelBuilder.clone(op, rootMapping);
+      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
+        auto hoisted =
+            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
+        rootMapping.map(&*alloca, &*hoisted);
+        rootMapping.map(alloca.getResult(), hoisted.getResult());
+        copyPrivate.push_back(hoisted);
       } else {
+        singleBuilder.clone(op, singleMapping);
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
         for (auto res : op.getResults()) {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 19123e71cacf6..b78cfd80e17ac 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -1,5 +1,7 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
+// checks:
+// nowait on final omp.single
 func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   omp.parallel {
     omp.workshare {
@@ -37,6 +39,8 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // -----
 
+// checks:
+// fir.alloca hoisted out and copyprivate'd
 func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   omp.workshare {
     %c1_i32 = arith.constant 1 : i32
@@ -73,7 +77,6 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   return
 }
 
-
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
@@ -130,9 +133,9 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           return
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_workshare_copy_llvm_ptr(
-// CHECK-SAME:                                                %[[VAL_0:.*]]: !llvm.ptr,
-// CHECK-SAME:                                                %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK-LABEL:   func.func private @_workshare_copy_i32(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: !fir.ref<i32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: !fir.ref<i32>) {
 // CHECK:           return
 // CHECK:         }
 
@@ -141,46 +144,40 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK:           %[[VAL_5:.*]] = arith.constant true
-// CHECK:           fir.if %[[VAL_5]] {
-// CHECK:             %[[VAL_6:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
-// CHECK:             %[[VAL_7:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:             omp.single copyprivate(%[[VAL_6]] -> @_workshare_copy_llvm_ptr : !llvm.ptr, %[[VAL_7]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:               %[[VAL_8:.*]] = fir.alloca i32
-// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !fir.ref<i32> to !llvm.ptr
-// CHECK:               llvm.store %[[VAL_9]], %[[VAL_6]] : !llvm.ptr, !llvm.ptr
-// CHECK:               fir.store %[[VAL_3]] to %[[VAL_8]] : !fir.ref<i32>
-// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_12:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               fir.store %[[VAL_12]] to %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_4:.*]] = arith.constant true
+// CHECK:           fir.if %[[VAL_4]] {
+// CHECK:             %[[VAL_5:.*]] = fir.alloca i32
+// CHECK:             %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
+// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_14:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> !llvm.ptr
-// CHECK:             %[[VAL_15:.*]] = builtin.unrealized_conversion_cast %[[VAL_14]] : !llvm.ptr to !fir.ref<i32>
-// CHECK:             %[[VAL_16:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:             %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_16]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:             %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_18]](%[[VAL_16]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_20:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_17]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_24:.*]] = arith.subi %[[VAL_22]], %[[VAL_23]] : i32
-// CHECK:                 %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_3]] : i32
-// CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_19]]#0 (%[[VAL_20]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_25]] to %[[VAL_26]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
+// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
+// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.single nowait {
-// CHECK:               "test.test1"(%[[VAL_15]]) : (!fir.ref<i32>) -> ()
-// CHECK:               hlfir.assign %[[VAL_19]]#0 to %[[VAL_17]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_19]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
+// CHECK:               hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.barrier

>From 2131c579b1c1ddc961f47a180d020a6a4dc1af12 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 15:52:23 +0900
Subject: [PATCH 27/36] More tests

---
 .../Transforms/OpenMP/lower-workshare2.mlir   | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare2.mlir

diff --git a/flang/test/Transforms/OpenMP/lower-workshare2.mlir b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
new file mode 100644
index 0000000000000..325a40d418445
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -0,0 +1,21 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @nonowait
+func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  // CHECK: omp.barrier
+  omp.workshare {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @nowait
+func.func @nowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  // CHECK-NOT: omp.barrier
+  omp.workshare nowait {
+    omp.terminator
+  }
+  return
+}

>From b9a4970ffc0b29cadaf2f831db580e1c3eccea83 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:13:39 +0900
Subject: [PATCH 28/36] Emit body for copy func

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 44 +++++--------------
 .../Transforms/OpenMP/lower-workshare.mlir    |  6 +++
 2 files changed, 17 insertions(+), 33 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index d0cd235d3eb07..20f45296a8159 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -69,10 +69,6 @@ struct SingleRegion {
   Block::iterator begin, end;
 };
 
-static bool isSupportedByFirAlloca(Type ty) {
-  return !isa<fir::ReferenceType>(ty);
-}
-
 static bool mustParallelizeOp(Operation *op) {
   // TODO as in shouldUseWorkshareLowering we be careful not to pick up
   // workshare_loop_wrapper in nested omp.parallel ops
@@ -98,14 +94,10 @@ static bool isSafeToParallelize(Operation *op) {
 static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                          fir::FirOpBuilder builder) {
   mlir::ModuleOp module = builder.getModule();
-  std::string copyFuncName;
-  if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
-    mlir::Type eleTy = rt.getEleTy();
-    copyFuncName =
-        fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
-  } else {
-    copyFuncName = "_workshare_copy_llvm_ptr";
-  }
+  auto rt = cast<fir::ReferenceType>(varType);
+  mlir::Type eleTy = rt.getEleTy();
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
 
   if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
     return decl;
@@ -120,6 +112,10 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
   builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
                       {loc, loc});
   builder.setInsertionPointToStart(&funcOp.getRegion().back());
+
+  Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(0));
+  builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(1));
+
   builder.create<mlir::func::ReturnOp>(loc);
   return funcOp;
 }
@@ -168,28 +164,10 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
     if (auto reloaded = rootMapping.lookupOrNull(v))
       return nullptr;
-    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
-    Value alloc, reloaded;
-    if (isSupportedByFirAlloca(ty)) {
-      alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
-      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
-      reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
-    } else {
-      auto one = allocaBuilder.create<LLVM::ConstantOp>(
-          loc, allocaBuilder.getI32Type(), 1);
-      alloc =
-          allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
-      Value toStore = singleBuilder
-                          .create<UnrealizedConversionCastOp>(
-                              loc, llvmPtrTy, singleMapping.lookup(v))
-                          .getResult(0);
-      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
-      reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
-      reloaded =
-          parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
-              .getResult(0);
-    }
+    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
+    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
     rootMapping.map(v, reloaded);
     return alloc;
   };
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index b78cfd80e17ac..997bc8d79f9b3 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -80,6 +80,8 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
 // CHECK:           return
 // CHECK:         }
 
@@ -130,12 +132,16 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
 // CHECK:           return
 // CHECK:         }
 
 // CHECK-LABEL:   func.func private @_workshare_copy_i32(
 // CHECK-SAME:                                           %[[VAL_0:.*]]: !fir.ref<i32>,
 // CHECK-SAME:                                           %[[VAL_1:.*]]: !fir.ref<i32>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<i32>
 // CHECK:           return
 // CHECK:         }
 

>From c8ee47be7ab0ad22afec8d7b5c2d765e0daf9cda Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:59:34 +0900
Subject: [PATCH 29/36] Test the tmp storing logic

---
 .../Transforms/OpenMP/lower-workshare.mlir    |  2 -
 .../Transforms/OpenMP/lower-workshare3.mlir   | 74 +++++++++++++++++++
 2 files changed, 74 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare3.mlir

diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 997bc8d79f9b3..063d3865065e0 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -36,7 +36,6 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   return
 }
 
-
 // -----
 
 // checks:
@@ -190,4 +189,3 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
-
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
new file mode 100644
index 0000000000000..84eded9450328
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -0,0 +1,74 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+
+// tests if the correct values are stored
+
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK-NOT: fir.alloca
+    omp.workshare {
+
+      %t1 = "test.test1"() : () -> i32
+      // CHECK: %[[T1:.*]] = "test.test1"
+      // CHECK: fir.store %[[T1]]
+      %t2 = "test.test2"() : () -> i32
+      // CHECK: %[[T2:.*]] = "test.test2"
+      // CHECK: fir.store %[[T2]]
+      %t3 = "test.test3"() : () -> i32
+      // CHECK: %[[T3:.*]] = "test.test3"
+      // CHECK-NOT: fir.store %[[T3]]
+      %t4 = "test.test4"() : () -> i32
+      // CHECK: %[[T4:.*]] = "test.test4"
+      // CHECK: fir.store %[[T4]]
+      %t5 = "test.test5"() : () -> i32
+      // CHECK: %[[T5:.*]] = "test.test5"
+      // CHECK: fir.store %[[T5]]
+      %t6 = "test.test6"() : () -> i32
+      // CHECK: %[[T6:.*]] = "test.test6"
+      // CHECK-NOT: fir.store %[[T6]]
+
+
+      "test.test1"(%t1) : (i32) -> ()
+      "test.test1"(%t2) : (i32) -> ()
+      "test.test1"(%t3) : (i32) -> ()
+
+      %true = arith.constant true
+      fir.if %true {
+        "test.test2"(%t3) : (i32) -> ()
+      }
+
+      %c1_i32 = arith.constant 1 : i32
+
+      %t5_pure_use = arith.addi %t5, %c1_i32 : i32
+
+      %t6_mem_effect_use = "test.test8"(%t6) : (i32) -> i32
+      // CHECK: %[[T6_USE:.*]] = "test.test8"
+      // CHECK: fir.store %[[T6_USE]]
+
+      %c42 = arith.constant 42 : index
+      %c1 = arith.constant 1 : index
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          "test.test10"(%t1) : (i32) -> ()
+          "test.test10"(%t5_pure_use) : (i32) -> ()
+          "test.test10"(%t6_mem_effect_use) : (i32) -> ()
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+
+      "test.test10"(%t2) : (i32) -> ()
+      fir.if %true {
+        "test.test10"(%t4) : (i32) -> ()
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}

>From e0890a502fb29f2606090591c3a18b246ab1472f Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 20:24:36 +0900
Subject: [PATCH 30/36] Clean up trivially dead ops

---
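Note: the cleanup step in this patch can be pictured with a small standalone sketch. It is only an illustration under stated assumptions, not the pass itself: ToyOp, cleanupToyBlock and makeSingleBlock are hypothetical names, and the use counts are assumed values chosen to resemble the single region of lower-workshare4.mlir after the pure constants have been cloned into it.

```cpp
#include <cstddef>
#include <list>
#include <string>

// Hypothetical stand-in for an IR operation; this is not an MLIR type.
struct ToyOp {
  std::string name;
  bool hasSideEffects; // e.g. a store, or an op with unknown effects
  int numUses;         // number of readers of this op's result (assumed)
};

// The "trivially dead" idea: no uses and no side effects.
static bool isTriviallyDead(const ToyOp &op) {
  return op.numUses == 0 && !op.hasSideEffects;
}

// Erase dead ops in one forward pass. The iterator is advanced before the
// current element is erased, which is the role llvm::make_early_inc_range
// plays in the patch. Returns the number of erased ops.
static std::size_t cleanupToyBlock(std::list<ToyOp> &block) {
  std::size_t erased = 0;
  for (auto it = block.begin(); it != block.end();) {
    auto current = it++;
    if (isTriviallyDead(*current)) {
      block.erase(current);
      ++erased;
    }
  }
  return erased;
}

// A toy single block: the constants 1 and 42 are only needed by the
// parallel loop, so their clones here have no uses.
static std::list<ToyOp> makeSingleBlock() {
  return {
      {"test.test1", true, 0},         // unused result, but has effects: kept
      {"arith.constant 1", false, 0},  // pure and unused: erased
      {"arith.constant 42", false, 0}, // pure and unused: erased
      {"arith.constant 2", false, 1},  // used by test.test3: kept
      {"test.test3", true, 0},         // has effects: kept
  };
}
```

Calling cleanupToyBlock(block) on the result of makeSingleBlock() leaves test.test1, the constant 2 and test.test3, which is the shape the CHECK lines in lower-workshare4.mlir expect for the omp.single body. In this sketch a single forward pass does not cascade: an op whose only user is erased later in the same pass stays in the block.
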
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 32 ++++-------
 .../Transforms/OpenMP/lower-workshare3.mlir   |  2 +-
 .../Transforms/OpenMP/lower-workshare4.mlir   | 55 +++++++++++++++++++
 3 files changed, 68 insertions(+), 21 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare4.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 20f45296a8159..a147db2cb5d59 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -152,6 +152,14 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
   return false;
 }
 
+/// We clone pure operations in both the parallel and single blocks. This
+/// function cleans them up if they end up with no uses.
+static void cleanupBlock(Block *block) {
+  for (Operation &op : llvm::make_early_inc_range(*block))
+    if (isOpTriviallyDead(&op))
+      op.erase();
+}
+
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
   OpBuilder rootBuilder(sourceRegion.getContext());
@@ -258,13 +266,8 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         singleOperands.copyprivateVars =
             moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                          singleBuilder, parallelBuilder);
+        cleanupBlock(singleBlock);
         for (auto var : singleOperands.copyprivateVars) {
-          Type ty;
-          if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
-            ty = firAlloca.getAllocatedType();
-          } else {
-            ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
-          }
           mlir::func::FuncOp funcOp =
               createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
           singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
@@ -302,6 +305,9 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 
     rootBuilder.clone(*block.getTerminator(), rootMapping);
   }
+
+  for (Block &targetBlock : targetRegion)
+    cleanupBlock(&targetBlock);
 }
 
 /// Lowers workshare to a sequence of single-thread regions and parallel loops
@@ -372,20 +378,6 @@ class LowerWorksharePass
 
       lowerWorkshare(wsOp);
     });
-
-    // Do folding
-    for (Operation *isolatedParent : parents) {
-      RewritePatternSet patterns(&getContext());
-      GreedyRewriteConfig config;
-      // prevent the pattern driver form merging blocks
-      config.enableRegionSimplification =
-          mlir::GreedySimplifyRegionLevel::Disabled;
-      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
-                                              std::move(patterns), config))) {
-        emitError(isolatedParent->getLoc(), "error in lower workshare\n");
-        signalPassFailure();
-      }
-    }
   }
 };
 } // namespace
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index 84eded9450328..aee95a464a31b 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -3,7 +3,7 @@
 
 // tests if the correct values are stored
 
-func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+func.func @wsfunc() {
   omp.parallel {
   // CHECK: fir.alloca
   // CHECK: fir.alloca
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
new file mode 100644
index 0000000000000..6cff0075b4fe5
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -0,0 +1,55 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+      %t1 = "test.test1"() : () -> i32
+
+      %c1 = arith.constant 1 : index
+      %c42 = arith.constant 42 : index
+
+      %c2 = arith.constant 2 : index
+      "test.test3"(%c2) : (index) -> ()
+
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          "test.test2"() : () -> ()
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @wsfunc() {
+// CHECK:           %[[VAL_0:.*]] = fir.alloca i32
+// CHECK:           omp.parallel {
+// CHECK:             %[[VAL_1:.*]] = arith.constant true
+// CHECK:             fir.if %[[VAL_1]] {
+// CHECK:               omp.single {
+// CHECK:                 %[[VAL_2:.*]] = "test.test1"() : () -> i32
+// CHECK:                 %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:                 "test.test3"(%[[VAL_3]]) : (index) -> ()
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_5:.*]] = arith.constant 42 : index
+// CHECK:               omp.wsloop nowait {
+// CHECK:                 omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) {
+// CHECK:                   "test.test2"() : () -> ()
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.barrier
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+

>From 270f0c61767434df32f7432b9b5c4cde33b22d31 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 21:55:14 +0900
Subject: [PATCH 31/36] Only handle single-block regions for now

---
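Note on the traversal order used in this patch: when a region has more than one block, parallelizeRegion now visits the blocks by walking the dominator tree breadth-first. Presumably this is so that a block is cloned only after every block that dominates it, meaning values defined in a dominator are already in rootMapping when their uses are cloned. The sketch below illustrates only that ordering property on a toy tree. `DomNode`, `breadthFirstOrder`, and `diamondDomTree` are hypothetical names for illustration and are not LLVM or MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <string>
#include <vector>

// Hypothetical stand-in for a dominator-tree node (not an LLVM type).
struct DomNode {
  std::string block;                 // name of the block this node represents
  std::vector<std::size_t> children; // indices of immediately dominated blocks
};

// Returns block names in breadth-first order, starting at the root (index 0).
// Every block appears after all of its dominators, because a node is only
// enqueued once its parent in the dominator tree has been visited.
std::vector<std::string> breadthFirstOrder(const std::vector<DomNode> &tree) {
  std::vector<std::string> order;
  if (tree.empty())
    return order;
  std::queue<std::size_t> worklist;
  worklist.push(0);
  while (!worklist.empty()) {
    std::size_t idx = worklist.front();
    worklist.pop();
    order.push_back(tree[idx].block);
    for (std::size_t child : tree[idx].children) {
      assert(child < tree.size() && "child index out of range");
      worklist.push(child);
    }
  }
  return order;
}

// Dominator tree of a diamond CFG: entry -> {then, else} -> merge.
// The entry block immediately dominates all three other blocks.
std::vector<DomNode> diamondDomTree() {
  return {{"entry", {1, 2, 3}}, {"then", {}}, {"else", {}}, {"merge", {}}};
}
```

Calling `breadthFirstOrder(diamondDomTree())` yields the entry block first, so anything it defines would be mapped before the blocks it dominates are processed. Any order that visits dominators first (for example a pre-order depth-first walk) would give the same guarantee.
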
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  80 +++++----
 .../Transforms/OpenMP/lower-workshare.mlir    | 154 +++++++++---------
 .../Transforms/OpenMP/lower-workshare4.mlir   |  31 ++--
 3 files changed, 143 insertions(+), 122 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index a147db2cb5d59..5998489c13d38 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -13,12 +13,14 @@
 #include <flang/Optimizer/Dialect/FIRType.h>
 #include <flang/Optimizer/HLFIR/HLFIROps.h>
 #include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
 #include <llvm/ADT/STLExtras.h>
 #include <llvm/ADT/SmallVectorExtras.h>
 #include <llvm/ADT/iterator_range.h>
 #include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
 #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
 #include <mlir/IR/BuiltinOps.h>
@@ -161,7 +163,8 @@ static void cleanupBlock(Block *block) {
 }
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
-                              IRMapping &rootMapping, Location loc) {
+                              IRMapping &rootMapping, Location loc,
+                              mlir::DominanceInfo &di) {
   OpBuilder rootBuilder(sourceRegion.getContext());
   ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
   OpBuilder copyFuncBuilder(m.getBodyRegion());
@@ -214,14 +217,19 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     return copyPrivate;
   };
 
-  // TODO Need to handle these (clone them) in dominator tree order
   for (Block &block : sourceRegion) {
-    rootBuilder.createBlock(
+    Block *targetBlock = rootBuilder.createBlock(
         &targetRegion, {}, block.getArgumentTypes(),
         llvm::map_to_vector(block.getArguments(),
                             [](BlockArgument arg) { return arg.getLoc(); }));
-    Operation *terminator = block.getTerminator();
+    rootMapping.map(&block, targetBlock);
+    rootMapping.map(block.getArguments(), targetBlock->getArguments());
+  }
 
+  auto handleOneBlock = [&](Block &block) {
+    Block &targetBlock = *rootMapping.lookup(&block);
+    rootBuilder.setInsertionPointToStart(&targetBlock);
+    Operation *terminator = block.getTerminator();
     SmallVector<std::variant<SingleRegion, Operation *>> regions;
 
     auto it = block.begin();
@@ -298,12 +306,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
           for (auto [region, clonedRegion] :
                llvm::zip(op->getRegions(), cloned->getRegions()))
-            parallelizeRegion(region, clonedRegion, rootMapping, loc);
+            parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
         }
       }
     }
 
     rootBuilder.clone(*block.getTerminator(), rootMapping);
+  };
+
+  if (sourceRegion.hasOneBlock()) {
+    handleOneBlock(sourceRegion.front());
+  } else {
+    auto &domTree = di.getDomTree(&sourceRegion);
+    for (auto node : llvm::breadth_first(domTree.getRootNode())) {
+      handleOneBlock(*node->getBlock());
+    }
   }
 
   for (Block &targetBlock : targetRegion)
@@ -336,47 +353,46 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 ///
 /// Note that we allocate temporary memory for values in omp.single's which need
 /// to be accessed in all threads in the closest omp.parallel
-void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
   Location loc = wsOp->getLoc();
   IRMapping rootMapping;
 
   OpBuilder rootBuilder(wsOp);
 
-  // TODO We need something like an scf;execute here, but that is not registered
-  // so using fir.if for now but it looks like it does not support multiple
-  // blocks so it doesnt work for multi block case...
-  auto ifOp = rootBuilder.create<fir::IfOp>(
-      loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
-  ifOp.getThenRegion().front().erase();
-
-  parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
-
-  Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
-  assert(isa<omp::TerminatorOp>(terminatorOp));
-  OpBuilder termBuilder(terminatorOp);
-
+  // TODO We need something like an scf.execute_region here, but that is not
+  // registered, so we use omp.workshare as a placeholder. We need this op
+  // because parallelizeRegion works on regions, not blocks.
+  omp::WorkshareOp newOp =
+      rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
   if (!wsOp.getNowait())
-    termBuilder.create<omp::BarrierOp>(loc);
-
-  termBuilder.create<fir::ResultOp>(loc, ValueRange());
-
-  terminatorOp->erase();
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc, di);
+
+  if (wsOp.getRegion().getBlocks().size() != 1)
+    return failure();
+
+  // Inline the contents of the placeholder workshare op into its parent block.
+  Block *theBlock = &newOp.getRegion().front();
+  Operation *term = theBlock->getTerminator();
+  Block *parentBlock = wsOp->getBlock();
+  parentBlock->getOperations().splice(newOp->getIterator(),
+                                      theBlock->getOperations());
+  assert(term->getNumOperands() == 0);
+  term->erase();
+  newOp->erase();
   wsOp->erase();
-
-  return;
+  return success();
 }
 
 class LowerWorksharePass
     : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
 public:
   void runOnOperation() override {
-    SmallPtrSet<Operation *, 8> parents;
+    mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
     getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
-      Operation *isolatedParent =
-          wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
-      parents.insert(isolatedParent);
-
-      lowerWorkshare(wsOp);
+      if (failed(lowerWorkshare(wsOp, di)))
+        signalPassFailure();
     });
   }
 };
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 063d3865065e0..b31e951223d56 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -86,43 +86,46 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant true
 // CHECK:           omp.parallel {
-// CHECK:             fir.if %[[VAL_4]] {
-// CHECK:               %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:               omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:                 %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:                 %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:                 %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:                 fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:                 %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               omp.wsloop {
-// CHECK:                 omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                   %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                   %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
-// CHECK:                   %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
-// CHECK:                   %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                   hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
-// CHECK:                   omp.yield
-// CHECK:                 }
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               omp.single nowait {
-// CHECK:                 hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:                 fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
-// CHECK:                 omp.terminator
+// CHECK:             %[[VAL_1:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK:               %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_5]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_7:.*]] = arith.constant 42 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : i32
+// CHECK:             %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_9]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_11]](%[[VAL_9]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_13:.*]] = arith.constant true
+// CHECK:             %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_14]]) to (%[[VAL_7]]) inclusive step (%[[VAL_14]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_8]] : i32
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                 omp.yield
 // CHECK:               }
-// CHECK:               omp.barrier
+// CHECK:               omp.terminator
 // CHECK:             }
+// CHECK:             omp.single nowait {
+// CHECK:               %[[VAL_20:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               %[[VAL_21:.*]] = fir.insert_value %[[VAL_20]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               hlfir.assign %[[VAL_12]]#0 to %[[VAL_10]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_22:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             %[[VAL_23:.*]] = fir.insert_value %[[VAL_22]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             omp.barrier
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
@@ -146,46 +149,51 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_4:.*]] = arith.constant true
-// CHECK:           fir.if %[[VAL_4]] {
-// CHECK:             %[[VAL_5:.*]] = fir.alloca i32
-// CHECK:             %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:             omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:               fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
-// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               omp.terminator
-// CHECK:             }
-// CHECK:             %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:             %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
-// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
-// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
-// CHECK:                 omp.yield
-// CHECK:               }
-// CHECK:               omp.terminator
-// CHECK:             }
-// CHECK:             omp.single nowait {
-// CHECK:               "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
-// CHECK:               hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
-// CHECK:               omp.terminator
+// CHECK:           %[[VAL_1:.*]] = fir.alloca i32
+// CHECK:           %[[VAL_2:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_2]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_1]] : !fir.ref<i32>
+// CHECK:             %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK:             %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_7:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:             fir.store %[[VAL_7]] to %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_7]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           %[[VAL_9:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_10:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_11:.*]] = fir.shape %[[VAL_10]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_15:.*]] = arith.constant true
+// CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:           omp.wsloop {
+// CHECK:             omp.loop_nest (%[[VAL_17:.*]]) : index = (%[[VAL_16]]) to (%[[VAL_10]]) inclusive step (%[[VAL_16]]) {
+// CHECK:               %[[VAL_18:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_17]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:               %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
+// CHECK:               %[[VAL_20:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
+// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_19]], %[[VAL_20]] : i32
+// CHECK:               %[[VAL_22:.*]] = arith.subi %[[VAL_21]], %[[VAL_9]] : i32
+// CHECK:               %[[VAL_23:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_17]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:               hlfir.assign %[[VAL_22]] to %[[VAL_23]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.yield
 // CHECK:             }
-// CHECK:             omp.barrier
+// CHECK:             omp.terminator
 // CHECK:           }
+// CHECK:           omp.single nowait {
+// CHECK:             %[[VAL_24:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             %[[VAL_25:.*]] = fir.insert_value %[[VAL_24]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             "test.test1"(%[[VAL_1]]) : (!fir.ref<i32>) -> ()
+// CHECK:             hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:             fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           %[[VAL_26:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:           %[[VAL_27:.*]] = fir.insert_value %[[VAL_26]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:           omp.barrier
 // CHECK:           return
 // CHECK:         }
+
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index 6cff0075b4fe5..d695a1c354517 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -29,25 +29,22 @@ func.func @wsfunc() {
 // CHECK-LABEL:   func.func @wsfunc() {
 // CHECK:           %[[VAL_0:.*]] = fir.alloca i32
 // CHECK:           omp.parallel {
-// CHECK:             %[[VAL_1:.*]] = arith.constant true
-// CHECK:             fir.if %[[VAL_1]] {
-// CHECK:               omp.single {
-// CHECK:                 %[[VAL_2:.*]] = "test.test1"() : () -> i32
-// CHECK:                 %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK:                 "test.test3"(%[[VAL_3]]) : (index) -> ()
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:               %[[VAL_5:.*]] = arith.constant 42 : index
-// CHECK:               omp.wsloop nowait {
-// CHECK:                 omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) {
-// CHECK:                   "test.test2"() : () -> ()
-// CHECK:                   omp.yield
-// CHECK:                 }
-// CHECK:                 omp.terminator
+// CHECK:             omp.single {
+// CHECK:               %[[VAL_1:.*]] = "test.test1"() : () -> i32
+// CHECK:               %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:               "test.test3"(%[[VAL_2]]) : (index) -> ()
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK:             omp.wsloop nowait {
+// CHECK:               omp.loop_nest (%[[VAL_5:.*]]) : index = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_3]]) {
+// CHECK:                 "test.test2"() : () -> ()
+// CHECK:                 omp.yield
 // CHECK:               }
-// CHECK:               omp.barrier
+// CHECK:               omp.terminator
 // CHECK:             }
+// CHECK:             omp.barrier
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return

>From 0a35b4c2f7e2b827a07f5cd6310233fd7a05a1de Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 6 Aug 2024 13:52:20 +0900
Subject: [PATCH 32/36] Fix tests for custom assembly for loop wrapper

---
 flang/test/Transforms/OpenMP/lower-workshare.mlir  | 8 ++++----
 flang/test/Transforms/OpenMP/lower-workshare3.mlir | 4 ++--
 flang/test/Transforms/OpenMP/lower-workshare4.mlir | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index b31e951223d56..9347863dc4a60 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -13,7 +13,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
       %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
       %true = arith.constant true
       %c1 = arith.constant 1 : index
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
           %8 = fir.load %7 : !fir.ref<i32>
@@ -23,7 +23,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
       %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
       %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
       %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
@@ -52,7 +52,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
     %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
     %true = arith.constant true
     %c1 = arith.constant 1 : index
-    "omp.workshare_loop_wrapper"() ({
+    omp.workshare_loop_wrapper {
       omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
         %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
         %8 = fir.load %7 : !fir.ref<i32>
@@ -64,7 +64,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
         omp.yield
       }
       omp.terminator
-    }) : () -> ()
+    }
     %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
     %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
     %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index aee95a464a31b..afb41d95e7198 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -52,7 +52,7 @@ func.func @wsfunc() {
 
       %c42 = arith.constant 42 : index
       %c1 = arith.constant 1 : index
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           "test.test10"(%t1) : (i32) -> ()
           "test.test10"(%t5_pure_use) : (i32) -> ()
@@ -60,7 +60,7 @@ func.func @wsfunc() {
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
 
       "test.test10"(%t2) : (i32) -> ()
       fir.if %true {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index d695a1c354517..0a70007a9e78d 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -12,13 +12,13 @@ func.func @wsfunc() {
       %c2 = arith.constant 2 : index
       "test.test3"(%c2) : (index) -> ()
 
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           "test.test2"() : () -> ()
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
       omp.terminator
     }
     omp.terminator

>From 73953920573901bf9fae3d9e84aee59af603a8b6 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 14:43:50 +0900
Subject: [PATCH 33/36] Only run the lower workshare pass if OpenMP is enabled

---
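Note on the gating in this patch: createHLFIRToFIRPassPipeline gains an enableOpenMP parameter, and the LowerWorkshare pass is appended only when it is true. A minimal model of that shape follows, assuming the pipeline can be represented as a list of pass names. `buildHlfirToFirPipeline` is a hypothetical helper, and the strings are just labels taken from the create* calls in the diff, not command-line flags.

```cpp
#include <string>
#include <vector>

// Hypothetical model of the tail of the HLFIR-to-FIR pipeline: the three
// unconditional passes are followed by LowerWorkshare only when OpenMP is
// enabled, mirroring the `if (enableOpenMP)` guard added by the patch.
std::vector<std::string> buildHlfirToFirPipeline(bool enableOpenMP) {
  std::vector<std::string> passes = {"LowerHLFIRIntrinsics", "BufferizeHLFIR",
                                     "ConvertHLFIRtoFIR"};
  if (enableOpenMP)
    passes.push_back("LowerWorkshare");
  return passes;
}
```

Because the new parameter has no default and precedes optLevel, every caller has to pass it explicitly, which is why the patch also touches the call in createMLIRToLLVMPassPipeline and the tool drivers.
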
 flang/include/flang/Tools/CLOptions.inc      |  7 ++++---
 flang/include/flang/Tools/CrossToolHelpers.h |  1 +
 flang/lib/Frontend/FrontendActions.cpp       | 10 +++++++++-
 flang/tools/bbc/bbc.cpp                      |  5 ++++-
 flang/tools/tco/tco.cpp                      |  1 +
 5 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index a565effebfa92..fbca7b6838ada 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -328,7 +328,7 @@ inline void createDefaultFIROptimizerPassPipeline(
 /// \param optLevel - optimization level used for creating FIR optimization
 ///   passes pipeline
 inline void createHLFIRToFIRPassPipeline(
-    mlir::PassManager &pm, llvm::OptimizationLevel optLevel = defaultOptLevel) {
+    mlir::PassManager &pm, bool enableOpenMP, llvm::OptimizationLevel optLevel = defaultOptLevel) {
   if (optLevel.isOptimizingForSpeed()) {
     addCanonicalizerPassWithoutRegionSimplification(pm);
     addNestedPassToAllTopLevelOperations(
@@ -345,7 +345,8 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  pm.addPass(flangomp::createLowerWorkshare());
+  if (enableOpenMP)
+    pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
@@ -416,7 +417,7 @@ inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
 ///   passes pipeline
 inline void createMLIRToLLVMPassPipeline(mlir::PassManager &pm,
     MLIRToLLVMPassPipelineConfig &config, llvm::StringRef inputFilename = {}) {
-  fir::createHLFIRToFIRPassPipeline(pm, config.OptLevel);
+  fir::createHLFIRToFIRPassPipeline(pm, config.EnableOpenMP, config.OptLevel);
 
   // Add default optimizer pass pipeline.
   fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h
index 1d890fd8e1f6f..dd258864ff7f2 100644
--- a/flang/include/flang/Tools/CrossToolHelpers.h
+++ b/flang/include/flang/Tools/CrossToolHelpers.h
@@ -123,6 +123,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
       false; ///< Set no-signed-zeros-fp-math attribute for functions.
   bool UnsafeFPMath = false; ///< Set unsafe-fp-math attribute for functions.
   bool NSWOnLoopVarInc = false; ///< Add nsw flag to loop variable increments.
+  bool EnableOpenMP = false; ///< Enable OpenMP lowering.
 };
 
 struct OffloadModuleOpts {
diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index 5c86bd947ce73..db5c564933752 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -711,7 +711,11 @@ void CodeGenAction::lowerHLFIRToFIR() {
   pm.enableVerifier(/*verifyPasses=*/true);
 
   // Create the pass pipeline
-  fir::createHLFIRToFIRPassPipeline(pm, level);
+  fir::createHLFIRToFIRPassPipeline(
+      pm,
+      ci.getInvocation().getFrontendOpts().features.IsEnabled(
+          Fortran::common::LanguageFeature::OpenMP),
+      level);
   (void)mlir::applyPassManagerCLOptions(pm);
 
   if (!mlir::succeeded(pm.run(*mlirModule))) {
@@ -824,6 +828,10 @@ void CodeGenAction::generateLLVMIR() {
     config.VScaleMax = vsr->second;
   }
 
+  if (ci.getInvocation().getFrontendOpts().features.IsEnabled(
+          Fortran::common::LanguageFeature::OpenMP))
+    config.EnableOpenMP = true;
+
   if (ci.getInvocation().getLoweringOpts().getNSWOnLoopVarInc())
     config.NSWOnLoopVarInc = true;
 
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index 07eef065daf6f..681b23883df44 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -429,7 +429,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
 
     if (emitFIR && useHLFIR) {
       // lower HLFIR to FIR
-      fir::createHLFIRToFIRPassPipeline(pm, llvm::OptimizationLevel::O2);
+      fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP,
+                                        llvm::OptimizationLevel::O2);
       if (mlir::failed(pm.run(mlirModule))) {
         llvm::errs() << "FATAL: lowering from HLFIR to FIR failed";
         return mlir::failure();
@@ -444,6 +445,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
 
     // Add O2 optimizer pass pipeline.
     MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+    if (enableOpenMP)
+      config.EnableOpenMP = true;
     config.NSWOnLoopVarInc = setNSW;
     fir::registerDefaultInlinerPass(config);
     fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/tools/tco/tco.cpp b/flang/tools/tco/tco.cpp
index a8c64333109ae..06892cdc3f6a8 100644
--- a/flang/tools/tco/tco.cpp
+++ b/flang/tools/tco/tco.cpp
@@ -138,6 +138,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
       return mlir::failure();
   } else {
     MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+    config.EnableOpenMP = true;  // assume the input contains OpenMP
     config.AliasAnalysis = true; // enabled when optimizing for speed
     if (codeGenLLVM) {
       // Run only CodeGen passes.
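
Note on patch 33: the gating can be pictured with a toy pipeline builder. This is a minimal sketch, not flang code. `ToyPassManager` and `createToyHLFIRToFIRPipeline` are hypothetical stand-ins for `mlir::PassManager` and `fir::createHLFIRToFIRPassPipeline`, and the pass names are only the labels that appear in flang/test/Fir/basic-program.fir.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::PassManager: it only records pass names.
struct ToyPassManager {
  std::vector<std::string> passes;
  void addPass(const std::string &name) { passes.push_back(name); }
};

// Sketch of the gated pipeline: LowerWorkshare is appended only when the
// caller reports that OpenMP is enabled.
static void createToyHLFIRToFIRPipeline(ToyPassManager &pm, bool enableOpenMP) {
  pm.addPass("LowerHLFIRIntrinsics");
  pm.addPass("BufferizeHLFIR");
  pm.addPass("ConvertHLFIRtoFIR");
  if (enableOpenMP)
    pm.addPass("LowerWorkshare");
}
```

With `enableOpenMP` set, the recorded list ends in `LowerWorkshare`; without it the list stops at `ConvertHLFIRtoFIR`. tco sets the flag unconditionally in this patch.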

>From e2705fe4a350517012de0b04fef2670a270742cf Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 16:16:38 +0900
Subject: [PATCH 34/36] Implement some missing functionality

---
 flang/include/flang/Optimizer/OpenMP/Passes.h |   3 +
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 113 ++++++++++++------
 2 files changed, 81 insertions(+), 35 deletions(-)

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
index 11fa4e59f891e..feb395f1a12db 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.h
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -25,6 +25,9 @@ namespace flangomp {
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 
+/// Implements the logic specified in the 2.8.3  workshare Construct section of
+/// the OpenMP standard which specifies what statements or constructs shall be
+/// divided into units of work.
 bool shouldUseWorkshareLowering(mlir::Operation *op);
 
 } // namespace flangomp
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 5998489c13d38..e921b80d0c571 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -5,7 +5,15 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-// Lower omp workshare construct.
+//
+// This file implements the lowering of omp.workshare to other omp constructs.
+//
+// This pass is tasked with parallelizing the loops nested in
+// workshare_loop_wrapper while both the Fortran to mlir lowering and the hlfir
+// to fir lowering pipelines are responsible for emitting the
+// workshare_loop_wrapper ops where appropriate according to the
+// `shouldUseWorkshareLowering` function.
+//
 //===----------------------------------------------------------------------===//
 
 #include <flang/Optimizer/Builder/FIRBuilder.h>
@@ -44,25 +52,52 @@ namespace flangomp {
 using namespace mlir;
 
 namespace flangomp {
+
+// Checks for nesting pattern below as we need to avoid sharing the work of
+// statements which are nested in some constructs such as omp.critical or
+// another omp.parallel.
+//
+// omp.workshare { // `wsOp`
+//   ...
+//     omp.T { // `parent`
+//       ...
+//         `op`
+//
+template <typename T>
+static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
+  T parent = op->getParentOfType<T>();
+  if (!parent)
+    return false;
+  return wsOp->isProperAncestor(parent);
+}
+
 bool shouldUseWorkshareLowering(Operation *op) {
-  // TODO this is insufficient, as we could have
-  // omp.parallel {
-  //   omp.workshare {
-  //     omp.parallel {
-  //       hlfir.elemental {}
-  //
-  // Then this hlfir.elemental shall _not_ use the lowering for workshare
-  //
-  // Standard says:
-  //   For a parallel construct, the construct is a unit of work with respect to
-  //   the workshare construct. The statements contained in the parallel
-  //   construct are executed by a new thread team.
-  //
-  // TODO similarly for single, critical, etc. Need to think through the
-  // patterns and implement this function.
-  //
-  return op->getParentOfType<omp::WorkshareOp>();
+  auto parentWorkshare = op->getParentOfType<omp::WorkshareOp>();
+
+  if (!parentWorkshare)
+    return false;
+
+  if (isNestedIn<omp::CriticalOp>(parentWorkshare, op))
+    return false;
+
+  // 2.8.3  workshare Construct
+  // For a parallel construct, the construct is a unit of work with respect to
+  // the workshare construct. The statements contained in the parallel construct
+  // are executed by a new thread team.
+  if (isNestedIn<omp::ParallelOp>(parentWorkshare, op))
+    return false;
+
+  // 2.8.2  single Construct
+  // Binding The binding thread set for a single region is the current team. A
+  // single region binds to the innermost enclosing parallel region.
+  // Description Only one of the encountering threads will execute the
+  // structured block associated with the single construct.
+  if (isNestedIn<omp::SingleOp>(parentWorkshare, op))
+    return false;
+
+  return true;
 }
+
 } // namespace flangomp
 
 namespace {
@@ -72,19 +107,27 @@ struct SingleRegion {
 };
 
 static bool mustParallelizeOp(Operation *op) {
-  // TODO as in shouldUseWorkshareLowering we be careful not to pick up
-  // workshare_loop_wrapper in nested omp.parallel ops
-  //
-  // e.g.
-  //
-  // omp.parallel {
-  //   omp.workshare {
-  //     omp.parallel {
-  //       omp.workshare {
-  //         omp.workshare_loop_wrapper {}
   return op
-      ->walk(
-          [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
+      ->walk([&](Operation *nested) {
+        // We need to be careful not to pick up workshare_loop_wrapper in nested
+        // omp.parallel{omp.workshare} regions, i.e. make sure that `nested`
+        // binds to the workshare region we are currently handling.
+        //
+        // For example:
+        //
+        // omp.parallel {
+        //   omp.workshare { // currently handling this
+        //     omp.parallel {
+        //       omp.workshare { // nested workshare
+        //         omp.workshare_loop_wrapper {}
+        //
+        // Therefore, we skip if we encounter a nested omp.workshare.
+        if (isa<omp::WorkshareOp>(op))
+          WalkResult::skip();
+        if (isa<omp::WorkshareLoopWrapperOp>(op))
+          WalkResult::interrupt();
+        WalkResult::advance();
+      })
       .wasInterrupted();
 }
 
@@ -340,7 +383,8 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 ///
 /// becomes
 ///
-/// omp.single {
+/// %tmp = fir.alloca
+/// omp.single copyprivate(%tmp) {
 ///   %a = fir.allocmem
 ///   fir.store %a %tmp
 /// }
@@ -352,16 +396,15 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 /// }
 ///
 /// Note that we allocate temporary memory for values in omp.single's which need
-/// to be accessed in all threads in the closest omp.parallel
+/// to be accessed by all threads and broadcast them using single's copyprivate
 LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
   Location loc = wsOp->getLoc();
   IRMapping rootMapping;
 
   OpBuilder rootBuilder(wsOp);
 
-  // TODO We need something like an scf.execute here, but that is not registered
-  // so using omp.workshare as a placeholder. We need this op as our
-  // parallelizeRegion works on regions and not blocks.
+  // This operation is just a placeholder which will be erased later. We need it
+  // because our `parallelizeRegion` function works on regions and not blocks.
   omp::WorkshareOp newOp =
       rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
   if (!wsOp.getNowait())

>From a938304d6b9dea253b64668809fa323f3d55e329 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 16:53:22 +0900
Subject: [PATCH 35/36] Fix tests

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  6 +--
 .../Transforms/OpenMP/lower-workshare2.mlir   |  2 +
 .../Transforms/OpenMP/lower-workshare3.mlir   |  2 +-
 .../Transforms/OpenMP/lower-workshare4.mlir   |  3 ++
 .../Transforms/OpenMP/lower-workshare6.mlir   | 51 +++++++++++++++++++
 5 files changed, 60 insertions(+), 4 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare6.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index e921b80d0c571..9557dd200cace 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -123,10 +123,10 @@ static bool mustParallelizeOp(Operation *op) {
         //
         // Therefore, we skip if we encounter a nested omp.workshare.
         if (isa<omp::WorkshareOp>(op))
-          WalkResult::skip();
+          return WalkResult::skip();
         if (isa<omp::WorkshareLoopWrapperOp>(op))
-          WalkResult::interrupt();
-        WalkResult::advance();
+          return WalkResult::interrupt();
+        return WalkResult::advance();
       })
       .wasInterrupted();
 }
diff --git a/flang/test/Transforms/OpenMP/lower-workshare2.mlir b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
index 325a40d418445..940662e0bdccc 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare2.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -1,5 +1,7 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
+// Check that we correctly handle nowait
+
 // CHECK-LABEL:   func.func @nonowait
 func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
   // CHECK: omp.barrier
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index afb41d95e7198..0921775751288 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -1,7 +1,7 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
 
-// tests if the correct values are stored
+// Check if we store the correct values
 
 func.func @wsfunc() {
   omp.parallel {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index 0a70007a9e78d..44f68cd2ca365 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -1,5 +1,8 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
+// Check that we clean up unused pure operations from either the parallel or
+// single regions
+
 func.func @wsfunc() {
   %a = fir.alloca i32
   omp.parallel {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare6.mlir b/flang/test/Transforms/OpenMP/lower-workshare6.mlir
new file mode 100644
index 0000000000000..b66f00a47c114
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare6.mlir
@@ -0,0 +1,51 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// Checks that the omp.workshare_loop_wrapper binds to the correct omp.workshare
+
+func.func @wsfunc() {
+  %c1 = arith.constant 1 : index
+  %c42 = arith.constant 42 : index
+  omp.parallel {
+    omp.workshare nowait {
+      omp.parallel {
+        omp.workshare nowait {
+          omp.workshare_loop_wrapper {
+            omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+              "test.test2"() : () -> ()
+              omp.yield
+            }
+            omp.terminator
+          }
+          omp.terminator
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @wsfunc() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
+// CHECK:           omp.parallel {
+// CHECK:             omp.single nowait {
+// CHECK:               omp.parallel {
+// CHECK:                 omp.wsloop nowait {
+// CHECK:                   omp.loop_nest (%[[VAL_2:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_1]]) inclusive step (%[[VAL_0]]) {
+// CHECK:                     "test.test2"() : () -> ()
+// CHECK:                     omp.yield
+// CHECK:                   }
+// CHECK:                   omp.terminator
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
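
Note on patch 35: the binding rule that lower-workshare6.mlir exercises can be sketched on a toy tree. This is an illustration under stated assumptions, not MLIR's `Operation::walk`. The `Op` type, the `walk` helper, and the local `WalkResult` enum are hypothetical, and the traversal is assumed to be pre-order so that `Skip` prunes the subtree. The sketch tests the visited node (`nested`), whereas the hunk as posted checks `op`.

```cpp
#include <functional>
#include <string>
#include <vector>

// Toy analogue of the three walk outcomes used in the patch.
enum class WalkResult { Advance, Skip, Interrupt };

// Toy operation: a name plus directly nested operations.
struct Op {
  std::string name;
  std::vector<Op> children;
};

// Pre-order walk over the toy tree; returns true if the walk was interrupted.
static bool walk(const Op &op,
                 const std::function<WalkResult(const Op &)> &fn) {
  WalkResult r = fn(op);
  if (r == WalkResult::Interrupt)
    return true;
  if (r == WalkResult::Skip)
    return false; // do not descend into this subtree
  for (const Op &child : op.children)
    if (walk(child, fn))
      return true;
  return false;
}

// Models mustParallelizeOp: true if a workshare_loop_wrapper is reachable
// without crossing a nested omp.workshare.
static bool mustParallelize(const Op &root) {
  return walk(root, [](const Op &nested) -> WalkResult {
    if (nested.name == "omp.workshare")
      return WalkResult::Skip;
    if (nested.name == "omp.workshare_loop_wrapper")
      return WalkResult::Interrupt;
    return WalkResult::Advance;
  });
}
```

Each branch has to return its result for the walk to act on it, which is the change this patch makes to the callback.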

>From 8fffc3da94333824f0e9dc3833188287a8cf9d61 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 17:16:36 +0900
Subject: [PATCH 36/36] Fix test

---
 flang/test/Fir/basic-program.fir | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index dda4f32872fef..6c2ff016f34ae 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -47,6 +47,7 @@ func.func @_QQmain() {
 // PASSES-NEXT:   LowerHLFIRIntrinsics
 // PASSES-NEXT:   BufferizeHLFIR
 // PASSES-NEXT:   ConvertHLFIRtoFIR
+// PASSES-NEXT:   LowerWorkshare
 // PASSES-NEXT:   CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd


