@@ -0,0 +1,446 @@
+//===- LowerWorkshare.cpp - special cases for bufferization -------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// This file implements the lowering of omp.workshare to other omp constructs.
+// This pass is tasked with parallelizing the loops nested in
+// workshare.loop_wrapper while both the Fortran to mlir lowering and the hlfir
+// to fir lowering pipelines are responsible for emitting the
+// workshare.loop_wrapper ops where appropriate according to the
+// `shouldUseWorkshareLowering` function.
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
+#include <variant>
+namespace flangomp {
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+#define DEBUG_TYPE "lower-workshare"
+using namespace mlir;
+namespace flangomp {
+// Returns true when the closest enclosing operation of type `T` around `op`
+// is itself strictly nested inside `wsOp`. This detects the pattern below,
+// which matters because statements nested in constructs such as omp.critical
+// or another omp.parallel must not have their work shared.
+// omp.workshare { // `wsOp`
+//   ...
+//     omp.T { // `enclosing`
+//       ...
+//         `op`
+template <typename T>
+static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
+  // Only the innermost `T` ancestor of `op` is considered.
+  if (T enclosing = op->getParentOfType<T>())
+    return wsOp->isProperAncestor(enclosing);
+  return false;
+// Decides whether `op` should go through the workshare lowering: it must be
+// enclosed by an omp.workshare and must not sit inside a construct (nested in
+// that workshare) which excludes it from work sharing.
+bool shouldUseWorkshareLowering(Operation *op) {
+  omp::WorkshareOp enclosingWs = op->getParentOfType<omp::WorkshareOp>();
+  if (!enclosingWs)
+    return false;
+  // The checks below are evaluated in order and stop at the first match:
+  //
+  // omp.critical: statements nested in a critical construct are not shared.
+  //
+  // 2.8.3  workshare Construct
+  // For a parallel construct, the construct is a unit of work with respect to
+  // the workshare construct. The statements contained in the parallel construct
+  // are executed by a new thread team.
+  //
+  // 2.8.2  single Construct
+  // Binding The binding thread set for a single region is the current team. A
+  // single region binds to the innermost enclosing parallel region.
+  // Description Only one of the encountering threads will execute the
+  // structured block associated with the single construct.
+  const bool insideExcludedConstruct =
+      isNestedIn<omp::CriticalOp>(enclosingWs, op) ||
+      isNestedIn<omp::ParallelOp>(enclosingWs, op) ||
+      isNestedIn<omp::SingleOp>(enclosingWs, op);
+  return !insideExcludedConstruct;
+} // namespace flangomp
+namespace {
+/// A run of consecutive operations within one block, described by the
+/// half-open iterator range [begin, end): `begin` is the first operation of
+/// the run and `end` is the operation just past it.
+struct SingleRegion {
+  Block::iterator begin;
+  Block::iterator end;
+/// Returns true if `op` is, or contains, an omp.workshare.loop_wrapper that
+/// binds to the workshare region currently being handled. Loop wrappers that
+/// live inside a nested omp.workshare are ignored.
+static bool mustParallelizeOp(Operation *op) {
+  // The walk must be pre-order: WalkResult::skip() only prunes a subtree when
+  // the callback runs on the parent before its nested operations are visited.
+  // With the default post-order walk the nested loop wrapper would already
+  // have been visited (and have interrupted the walk) before the enclosing
+  // omp.workshare is seen.
+  return op
+      ->walk<WalkOrder::PreOrder>([](Operation *nested) {
+        // We need to be careful not to pick up workshare.loop_wrapper in nested
+        // omp.parallel{omp.workshare} regions, i.e. make sure that `nested`
+        // binds to the workshare region we are currently handling.
+        //
+        // For example:
+        //
+        // omp.parallel {
+        //   omp.workshare { // currently handling this
+        //     omp.parallel {
+        //       omp.workshare { // nested workshare
+        //         omp.workshare.loop_wrapper {}
+        //
+        // Therefore, we skip if we encounter a nested omp.workshare.
+        // Note: the checks are on the visited operation `nested`, not on the
+        // walk root `op`.
+        if (isa<omp::WorkshareOp>(nested))
+          return WalkResult::skip();
+        if (isa<omp::WorkshareLoopWrapperOp>(nested))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
+/// Returns true for operations that may be cloned into both the single and
+/// the parallel code paths: hlfir/fir declare operations and any operation
+/// that is free of memory effects.
+static bool isSafeToParallelize(Operation *op) {
+  if (isa<hlfir::DeclareOp, fir::DeclareOp>(op))
+    return true;
+  return isMemoryEffectFree(op);
+/// Simple shallow copies suffice for our purposes in this pass, so we implement
+/// this simpler alternative to the full fledged `createCopyFunc` in the
+/// frontend.
+///
+/// Returns a private function `(varType, varType) -> ()` whose name is derived
+/// from the element type of `varType` (which must be a fir.ref type). An
+/// existing function with that name in the module is reused. The function
+/// follows the copyprivate helper convention `copy(destination, source)`:
+/// argument 0 is the destination and argument 1 is the source.
+static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
+                                         fir::FirOpBuilder builder) {
+  mlir::ModuleOp module = builder.getModule();
+  auto rt = cast<fir::ReferenceType>(varType);
+  mlir::Type eleTy = rt.getEleTy();
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  // One copy function per element type is enough; reuse it if already built.
+  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
+    return decl;
+  // create function
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
+  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
+  mlir::func::FuncOp funcOp =
+      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+                      {loc, loc});
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+  // Copy source (argument 1) into destination (argument 0). Loading from
+  // argument 0 and storing into argument 1 would copy in the wrong direction
+  // and overwrite the value that is supposed to be broadcast.
+  Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(1));
+  builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(0));
+  builder.create<mlir::func::ReturnOp>(loc);
+  return funcOp;
+/// Returns true if `user` lies outside the single region `sr`. `user` is first
+/// lifted to its ancestor that is an immediate child of `parentOp`; that
+/// ancestor counts as inside only when it is in the same block as `sr`, comes
+/// strictly after `sr.begin` and strictly before `sr.end`.
+static bool isUserOutsideSR(Operation *user, Operation *parentOp,
+                            SingleRegion sr) {
+  // Climb until we reach the operation directly nested in `parentOp`.
+  Operation *ancestor = user;
+  while (ancestor->getParentOp() != parentOp)
+    ancestor = ancestor->getParentOp();
+  if (sr.begin->getBlock() != ancestor->getBlock())
+    return true;
+  if (!ancestor->isBeforeInBlock(&*sr.end))
+    return true;
+  return !sr.begin->isBeforeInBlock(ancestor);
+/// Returns true if `v` is used outside of the single region `sr`, either
+/// directly or through the results of safe-to-parallelize users that sit
+/// directly in the block of `sr`.
+static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
+  Block *regionBlock = sr.begin->getBlock();
+  Operation *enclosingOp = regionBlock->getParentOp();
+  for (OpOperand &operand : v.getUses()) {
+    Operation *consumer = operand.getOwner();
+    if (isUserOutsideSR(consumer, enclosingOp, sr))
+      return true;
+    // Results of nested users cannot be used outside of the SR, and a
+    // non-safe to parallelize operation will be handled separately, so only
+    // follow the results of direct, safe-to-parallelize users.
+    const bool followResults =
+        consumer->getBlock() == regionBlock && isSafeToParallelize(consumer);
+    if (!followResults)
+      continue;
+    if (llvm::any_of(consumer->getResults(), [&](Value result) {
+          return isTransitivelyUsedOutside(result, sr);
+        }))
+      return true;
+  }
+  return false;
+/// Pure operations are cloned into both the parallel and the single blocks;
+/// this function erases the ones in `block` that are trivially dead.
+/// Operations are visited from last to first, with the iterator advanced
+/// before each potential erase.
+static void cleanupBlock(Block *block) {
+  auto opsBackToFront = llvm::reverse(block->getOperations());
+  for (Operation &candidate : llvm::make_early_inc_range(opsBackToFront)) {
+    if (!isOpTriviallyDead(&candidate))
+      continue;
+    candidate.erase();
+  }
+static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
+                              IRMapping &rootMapping, Location loc,
+                              mlir::DominanceInfo &di) {
+  OpBuilder rootBuilder(sourceRegion.getContext());
+  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
+  OpBuilder copyFuncBuilder(m.getBodyRegion());
+  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
+  auto mapReloadedValue =
+      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
+          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
+    if (auto reloaded = rootMapping.lookupOrNull(v))
+      return nullptr;
+    Type ty = v.getType();
+    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
+    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
+    rootMapping.map(v, reloaded);
+    return alloc;
+  };
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
+                          OpBuilder singleBuilder,
+                          OpBuilder parallelBuilder) -> SmallVector<Value> {
+    IRMapping singleMapping = rootMapping;
+    SmallVector<Value> copyPrivate;
+    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+      if (isSafeToParallelize(&op)) {
+        singleBuilder.clone(op, singleMapping);
+        parallelBuilder.clone(op, rootMapping);
+      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
+        auto hoisted =
+            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
+        rootMapping.map(&*alloca, &*hoisted);
+        rootMapping.map(alloca.getResult(), hoisted.getResult());
+        copyPrivate.push_back(hoisted);
+      } else {
+        singleBuilder.clone(op, singleMapping);
+        // Prepare reloaded values for results of operations that cannot be
+        // safely parallelized and which are used after the region `sr`
+        for (auto res : op.getResults()) {
+          if (isTransitivelyUsedOutside(res, sr)) {
+            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
+                                          parallelBuilder, singleMapping);
+            if (alloc)
+              copyPrivate.push_back(alloc);
+          }
+        }
+      }
+    }
+    singleBuilder.create<omp::TerminatorOp>(loc);
+    return copyPrivate;
+  };
+  for (Block &block : sourceRegion) {
+    Block *targetBlock = rootBuilder.createBlock(
+        &targetRegion, {}, block.getArgumentTypes(),
+        llvm::map_to_vector(block.getArguments(),
+                            [](BlockArgument arg) { return arg.getLoc(); }));
+    rootMapping.map(&block, targetBlock);
+    rootMapping.map(block.getArguments(), targetBlock->getArguments());
+  }
+  auto handleOneBlock = [&](Block &block) {
+    Block &targetBlock = *rootMapping.lookup(&block);
+    rootBuilder.setInsertionPointToStart(&targetBlock);
+    Operation *terminator = block.getTerminator();
+    SmallVector<std::variant<SingleRegion, Operation *>> regions;
+    auto it = block.begin();
+    auto getOneRegion = [&]() {
+      if (&*it == terminator)
+        return false;
+      if (mustParallelizeOp(&*it)) {
+        regions.push_back(&*it);
+        it++;
+        return true;
+      }
+      SingleRegion sr;
+      sr.begin = it;
+      while (&*it != terminator && !mustParallelizeOp(&*it))
+        it++;
+      sr.end = it;
+      assert(sr.begin != sr.end);
+      regions.push_back(sr);
+      return true;
+    };
+    while (getOneRegion())
+      ;
+    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
+      bool isLast = i + 1 == regions.size();
+      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        OpBuilder singleBuilder(sourceRegion.getContext());
+        Block *singleBlock = new Block();
ivanradanov wrote:

The problem here is that until we construct this omp.single block we do not know how many operands the omp.single will have. So if we would like to construct the omp.single which will be the parent of this block we need to do two passes and the code becomes more complex. (createBlock does not support creating orphan blocks). After we construct this block we know what values we need to broadcast outside of it so we can construct the omp.single parent with the correct operands. 


