[llvm-branch-commits] [flang] [flang] Lower omp.workshare to other omp constructs (PR #101446)

Ivan R. Ivanov via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Aug 19 00:55:39 PDT 2024


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/101446

>From 30e12940ef4b7c11d3ff749fc4af7d3e3075a6ed Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 22:06:55 +0900
Subject: [PATCH 01/19] [flang] Lower omp.workshare to other omp constructs

---
 flang/include/flang/Optimizer/OpenMP/Passes.h |   2 +
 .../include/flang/Optimizer/OpenMP/Passes.td  |   6 +-
 flang/include/flang/Tools/CLOptions.inc       |   1 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 259 ++++++++++++++++++
 .../Transforms/OpenMP/lower-workshare.mlir    |  81 ++++++
 .../Transforms/OpenMP/lower-workshare5.mlir   |  42 +++
 7 files changed, 391 insertions(+), 1 deletion(-)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare5.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
index 403d79667bf448..11fa4e59f891ea 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.h
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -25,6 +25,8 @@ namespace flangomp {
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 
+bool shouldUseWorkshareLowering(mlir::Operation *op);
+
 } // namespace flangomp
 
 #endif // FORTRAN_OPTIMIZER_OPENMP_PASSES_H
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index a192b9a33c1976..1c9d75d8cfaa18 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -1,4 +1,4 @@
-//===-- Passes.td - HLFIR pass definition file -------------*- tablegen -*-===//
+//===-- Passes.td - flang OpenMP pass definition -----------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -37,4 +37,8 @@ def FunctionFiltering : Pass<"omp-function-filtering"> {
   ];
 }
 
+def LowerWorkshare : Pass<"lower-workshare"> {
+  let summary = "Lower workshare construct";
+}
+
 #endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 05b2f31711add2..a565effebfa92e 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,6 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
+  pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index a8984d256b8f6a..2a38f157a851ce 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -4,6 +4,7 @@ add_flang_library(FlangOpenMPTransforms
   FunctionFiltering.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkshare.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
new file mode 100644
index 00000000000000..40975552d1fe33
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -0,0 +1,259 @@
+//===- LowerWorkshare.cpp - lower the omp.workshare construct -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Lower omp workshare construct.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+bool shouldUseWorkshareLowering(Operation *op) {
+  auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
+  if (!workshare)
+    return false;
+  return workshare->getParentOfType<omp::ParallelOp>();
+}
+} // namespace flangomp
+
+namespace {
+
+struct SingleRegion {
+  Block::iterator begin, end;
+};
+
+static bool isSupportedByFirAlloca(Type ty) {
+  return !isa<fir::ReferenceType>(ty);
+}
+
+static bool isSafeToParallelize(Operation *op) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface) {
+    return false;
+  }
+  interface.getEffects(effects);
+  if (effects.empty())
+    return true;
+
+  return false;
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.wsloop {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.wsloop {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which need
+/// to be accessed in all threads in the closest omp.parallel
+///
+/// TODO currently we need to be able to access the encompassing omp.parallel so
+/// that we can allocate temporaries accessible by all threads outside of it.
+/// In case we do not find it, we fall back to converting the omp.workshare to
+/// omp.single.
+/// To better handle this we should probably enable yielding values out of an
+/// omp.single which will be supported by the omp runtime.
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  assert(wsOp.getRegion().getBlocks().size() == 1);
+
+  Location loc = wsOp->getLoc();
+
+  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
+  if (!parallelOp) {
+    wsOp.emitWarning("cannot handle workshare, converting to single");
+    Operation *terminator = wsOp.getRegion().front().getTerminator();
+    wsOp->getBlock()->getOperations().splice(
+        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
+    terminator->erase();
+    return;
+  }
+
+  OpBuilder allocBuilder(parallelOp);
+  OpBuilder rootBuilder(wsOp);
+  IRMapping rootMapping;
+
+  omp::SingleOp singleOp = nullptr;
+
+  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
+                              IRMapping singleMapping) {
+    if (auto reloaded = rootMapping.lookupOrNull(v))
+      return;
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type ty = v.getType();
+    Value alloc, reloaded;
+    if (isSupportedByFirAlloca(ty)) {
+      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+    } else {
+      auto one = allocBuilder.create<LLVM::ConstantOp>(
+          loc, allocBuilder.getI32Type(), 1);
+      alloc =
+          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+      Value toStore = singleBuilder
+                          .create<UnrealizedConversionCastOp>(
+                              loc, llvmPtrTy, singleMapping.lookup(v))
+                          .getResult(0);
+      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
+      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded =
+          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+              .getResult(0);
+    }
+    rootMapping.map(v, reloaded);
+  };
+
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+    IRMapping singleMapping = rootMapping;
+
+    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+      singleBuilder.clone(op, singleMapping);
+      if (isSafeToParallelize(&op)) {
+        rootBuilder.clone(op, rootMapping);
+      } else {
+        // Prepare reloaded values for results of operations that cannot be
+        // safely parallelized and which are used after the region `sr`
+        for (auto res : op.getResults()) {
+          for (auto &use : res.getUses()) {
+            Operation *user = use.getOwner();
+            while (user->getParentOp() != wsOp)
+              user = user->getParentOp();
+            if (!user->isBeforeInBlock(&*sr.end)) {
+              // We need to reload
+              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            }
+          }
+        }
+      }
+    }
+    singleBuilder.create<omp::TerminatorOp>(loc);
+  };
+
+  Block *wsBlock = &wsOp.getRegion().front();
+  assert(wsBlock->getTerminator()->getNumOperands() == 0);
+  Operation *terminator = wsBlock->getTerminator();
+
+  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+
+  auto it = wsBlock->begin();
+  auto getSingleRegion = [&]() {
+    if (&*it == terminator)
+      return false;
+    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+      regions.push_back(pop);
+      it++;
+      return true;
+    }
+    SingleRegion sr;
+    sr.begin = it;
+    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+      it++;
+    sr.end = it;
+    assert(sr.begin != sr.end);
+    regions.push_back(sr);
+    return true;
+  };
+  while (getSingleRegion())
+    ;
+
+  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
+    bool isLast = i + 1 == regions.size();
+    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
+      omp::SingleOperands singleOperands;
+      if (isLast)
+        singleOperands.nowait = rootBuilder.getUnitAttr();
+      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      OpBuilder singleBuilder(singleOp);
+      singleBuilder.createBlock(&singleOp.getRegion());
+      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
+    } else {
+      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
+      if (!isLast)
+        rootBuilder.create<omp::BarrierOp>(loc);
+    }
+  }
+
+  if (!wsOp.getNowait())
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  wsOp->erase();
+
+  return;
+}
+
+class LowerWorksharePass
+    : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
+public:
+  void runOnOperation() override {
+    SmallPtrSet<Operation *, 8> parents;
+    getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
+      Operation *isolatedParent =
+          wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+      parents.insert(isolatedParent);
+
+      lowerWorkshare(wsOp);
+    });
+
+    // Do folding
+    for (Operation *isolatedParent : parents) {
+      RewritePatternSet patterns(&getContext());
+      GreedyRewriteConfig config;
+      // prevent the pattern driver from merging blocks
+      config.enableRegionSimplification =
+          mlir::GreedySimplifyRegionLevel::Disabled;
+      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
+                                              std::move(patterns), config))) {
+        emitError(isolatedParent->getLoc(), "error in lower workshare\n");
+        signalPassFailure();
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
new file mode 100644
index 00000000000000..a8d36443f08bda
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -0,0 +1,81 @@
+// RUN: fir-opt --lower-workshare %s | FileCheck %s
+
+module {
+// CHECK-LABEL:   func.func @simple(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           omp.parallel {
+// CHECK:             omp.single {
+// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
+// CHECK:               llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
+// CHECK:               %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
+// CHECK:             %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
+// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                 omp.yield
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.barrier
+// CHECK:             omp.single nowait {
+// CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.barrier
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+  func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
+    omp.parallel {
+      omp.workshare {
+        %c42 = arith.constant 42 : index
+        %c1_i32 = arith.constant 1 : i32
+        %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+        %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+        %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+        %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+        %true = arith.constant true
+        %c1 = arith.constant 1 : index
+        omp.wsloop {
+          omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+            %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+            %8 = fir.load %7 : !fir.ref<i32>
+            %9 = arith.subi %8, %c1_i32 : i32
+            %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+            hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+            omp.yield
+          }
+          omp.terminator
+        }
+        %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+        hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+        fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    return
+  }
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare5.mlir b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
new file mode 100644
index 00000000000000..177f8aa8f86c7c
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
@@ -0,0 +1,42 @@
+// XFAIL: *
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// TODO we can lower these but we have no guarantee that the parent of
+// omp.workshare supports multi-block regions, thus we fail for now.
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+    ^bb1:
+      %c1 = arith.constant 1 : i32
+      cf.br ^bb3(%c1: i32)
+    ^bb3(%arg1: i32):
+      "test.test2"(%arg1) : (i32) -> ()
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+    ^bb1:
+      %c1 = arith.constant 1 : i32
+      cf.br ^bb3(%c1: i32)
+    ^bb2:
+      "test.test2"(%r) : (i32) -> ()
+      omp.terminator
+    ^bb3(%arg1: i32):
+      %r = "test.test2"(%arg1) : (i32) -> i32
+      cf.br ^bb2
+    }
+    omp.terminator
+  }
+  return
+}

>From cbb14172cb5b72719081cd623f0608a2b6186950 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:41:09 +0900
Subject: [PATCH 02/19] Change to workshare loop wrapper op

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 24 ++++++++++++-------
 .../Transforms/OpenMP/lower-workshare.mlir    |  5 ++--
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 40975552d1fe33..cb342b60de4e8d 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/iterator_range.h"
 
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
 #include <variant>
 
 namespace flangomp {
@@ -73,7 +74,7 @@ static bool isSafeToParallelize(Operation *op) {
 ///
 /// omp.workshare {
 ///   %a = fir.allocmem
-///   omp.wsloop {}
+///   omp.workshare_loop_wrapper {}
 ///   fir.call Assign %b %a
 ///   fir.freemem %a
 /// }
@@ -85,7 +86,7 @@ static bool isSafeToParallelize(Operation *op) {
 ///   fir.store %a %tmp
 /// }
 /// %a_reloaded = fir.load %tmp
-/// omp.wsloop {}
+/// omp.workshare_loop_wrapper {}
 /// omp.single {
 ///   fir.call Assign %b %a_reloaded
 ///   fir.freemem %a_reloaded
@@ -180,20 +181,20 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
   assert(wsBlock->getTerminator()->getNumOperands() == 0);
   Operation *terminator = wsBlock->getTerminator();
 
-  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+  SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
 
   auto it = wsBlock->begin();
   auto getSingleRegion = [&]() {
     if (&*it == terminator)
       return false;
-    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+    if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
       regions.push_back(pop);
       it++;
       return true;
     }
     SingleRegion sr;
     sr.begin = it;
-    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+    while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
       it++;
     sr.end = it;
     assert(sr.begin != sr.end);
@@ -214,9 +215,16 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
       singleBuilder.createBlock(&singleOp.getRegion());
       moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
     } else {
-      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
-      if (!isLast)
-        rootBuilder.create<omp::BarrierOp>(loc);
+      omp::WsloopOperands wsloopOperands;
+      if (isLast)
+        wsloopOperands.nowait = rootBuilder.getUnitAttr();
+      auto wsloop =
+          rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+      auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
+      auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+          rootBuilder.clone(*wslw, rootMapping));
+      wsloop.getRegion().takeBody(clonedWslw.getRegion());
+      clonedWslw->erase();
     }
   }
 
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index a8d36443f08bda..cb5791d35916a9 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -34,7 +34,6 @@ module {
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             omp.barrier
 // CHECK:             omp.single nowait {
 // CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
 // CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
@@ -56,7 +55,7 @@ module {
         %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
         %true = arith.constant true
         %c1 = arith.constant 1 : index
-        omp.wsloop {
+        "omp.workshare_loop_wrapper"() ({
           omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
             %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
             %8 = fir.load %7 : !fir.ref<i32>
@@ -66,7 +65,7 @@ module {
             omp.yield
           }
           omp.terminator
-        }
+        }) : () -> ()
         %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
         %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
         %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>

>From 5f4f9bf1267cff9f9f413bd3b667661aab75b9b3 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 16:47:27 +0900
Subject: [PATCH 03/19] Move single op declaration

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index cb342b60de4e8d..2322d2acbc0138 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -120,8 +120,6 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
   OpBuilder rootBuilder(wsOp);
   IRMapping rootMapping;
 
-  omp::SingleOp singleOp = nullptr;
-
   auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
                               IRMapping singleMapping) {
     if (auto reloaded = rootMapping.lookupOrNull(v))
@@ -210,7 +208,8 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
       omp::SingleOperands singleOperands;
       if (isLast)
         singleOperands.nowait = rootBuilder.getUnitAttr();
-      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      omp::SingleOp singleOp =
+          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
       OpBuilder singleBuilder(singleOp);
       singleBuilder.createBlock(&singleOp.getRegion());
       moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);

>From 79ae18461e1909637f206d10ac286d8dcb5ef18c Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 2 Aug 2024 17:13:58 +0900
Subject: [PATCH 04/19] Schedule pass properly

---
 flang/include/flang/Tools/CLOptions.inc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index a565effebfa92e..b256dba3c92949 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  pm.addPass(flangomp::createLowerWorkshare());
+  addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed

>From 5911275afa122ebc77ef5f653799cb19ee5af187 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 00:30:40 +0900
Subject: [PATCH 05/19] Correctly handle nested loop nests to be parallelized
 by workshare

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 256 ++++++++++--------
 1 file changed, 138 insertions(+), 118 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 2322d2acbc0138..8e79d1401c01c6 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -19,9 +19,14 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/iterator_range.h"
 
+#include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <variant>
 
 namespace flangomp {
@@ -52,90 +57,40 @@ static bool isSupportedByFirAlloca(Type ty) {
   return !isa<fir::ReferenceType>(ty);
 }
 
-static bool isSafeToParallelize(Operation *op) {
-  if (isa<fir::DeclareOp>(op))
-    return true;
-
-  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
-  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
-  if (!interface) {
-    return false;
-  }
-  interface.getEffects(effects);
-  if (effects.empty())
-    return true;
-
-  return false;
+static bool mustParallelizeOp(Operation *op) {
+  return op
+      ->walk(
+          [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
+      .wasInterrupted();
 }
 
-/// Lowers workshare to a sequence of single-thread regions and parallel loops
-///
-/// For example:
-///
-/// omp.workshare {
-///   %a = fir.allocmem
-///   omp.workshare_loop_wrapper {}
-///   fir.call Assign %b %a
-///   fir.freemem %a
-/// }
-///
-/// becomes
-///
-/// omp.single {
-///   %a = fir.allocmem
-///   fir.store %a %tmp
-/// }
-/// %a_reloaded = fir.load %tmp
-/// omp.workshare_loop_wrapper {}
-/// omp.single {
-///   fir.call Assign %b %a_reloaded
-///   fir.freemem %a_reloaded
-/// }
-///
-/// Note that we allocate temporary memory for values in omp.single's which need
-/// to be accessed in all threads in the closest omp.parallel
-///
-/// TODO currently we need to be able to access the encompassing omp.parallel so
-/// that we can allocate temporaries accessible by all threads outside of it.
-/// In case we do not find it, we fall back to converting the omp.workshare to
-/// omp.single.
-/// To better handle this we should probably enable yielding values out of an
-/// omp.single which will be supported by the omp runtime.
-void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
-  assert(wsOp.getRegion().getBlocks().size() == 1);
-
-  Location loc = wsOp->getLoc();
+static bool isSafeToParallelize(Operation *op) {
+  return isa<fir::DeclareOp>(op) || isPure(op);
+}
 
-  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
-  if (!parallelOp) {
-    wsOp.emitWarning("cannot handle workshare, converting to single");
-    Operation *terminator = wsOp.getRegion().front().getTerminator();
-    wsOp->getBlock()->getOperations().splice(
-        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
-    terminator->erase();
-    return;
-  }
-
-  OpBuilder allocBuilder(parallelOp);
-  OpBuilder rootBuilder(wsOp);
-  IRMapping rootMapping;
+static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
+                              IRMapping &rootMapping, Location loc) {
+  Operation *parentOp = sourceRegion.getParentOp();
+  OpBuilder rootBuilder(sourceRegion.getContext());
 
+  // TODO need to copyprivate the alloca's
   auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
                               IRMapping singleMapping) {
+    OpBuilder allocaBuilder(&targetRegion.front().front());
     if (auto reloaded = rootMapping.lookupOrNull(v))
       return;
-    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
     Value alloc, reloaded;
     if (isSupportedByFirAlloca(ty)) {
-      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
       singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
       reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
     } else {
-      auto one = allocBuilder.create<LLVM::ConstantOp>(
-          loc, allocBuilder.getI32Type(), 1);
+      auto one = allocaBuilder.create<LLVM::ConstantOp>(
+          loc, allocaBuilder.getI32Type(), 1);
       alloc =
-          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+          allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
       Value toStore = singleBuilder
                           .create<UnrealizedConversionCastOp>(
                               loc, llvmPtrTy, singleMapping.lookup(v))
@@ -162,9 +117,10 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
         for (auto res : op.getResults()) {
           for (auto &use : res.getUses()) {
             Operation *user = use.getOwner();
-            while (user->getParentOp() != wsOp)
+            while (user->getParentOp() != parentOp)
               user = user->getParentOp();
-            if (!user->isBeforeInBlock(&*sr.end)) {
+            if (!(user->isBeforeInBlock(&*sr.end) &&
+                  sr.begin->isBeforeInBlock(user))) {
               // We need to reload
               mapReloadedValue(use.get(), singleBuilder, singleMapping);
             }
@@ -175,61 +131,125 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
     singleBuilder.create<omp::TerminatorOp>(loc);
   };
 
-  Block *wsBlock = &wsOp.getRegion().front();
-  assert(wsBlock->getTerminator()->getNumOperands() == 0);
-  Operation *terminator = wsBlock->getTerminator();
+  // TODO Need to handle these (clone them) in dominator tree order
+  for (Block &block : sourceRegion) {
+    rootBuilder.createBlock(
+        &targetRegion, {}, block.getArgumentTypes(),
+        llvm::map_to_vector(block.getArguments(),
+                            [](BlockArgument arg) { return arg.getLoc(); }));
+    Operation *terminator = block.getTerminator();
 
-  SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
+    SmallVector<std::variant<SingleRegion, Operation *>> regions;
 
-  auto it = wsBlock->begin();
-  auto getSingleRegion = [&]() {
-    if (&*it == terminator)
-      return false;
-    if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
-      regions.push_back(pop);
-      it++;
+    auto it = block.begin();
+    auto getOneRegion = [&]() {
+      if (&*it == terminator)
+        return false;
+      if (mustParallelizeOp(&*it)) {
+        regions.push_back(&*it);
+        it++;
+        return true;
+      }
+      SingleRegion sr;
+      sr.begin = it;
+      while (&*it != terminator && !mustParallelizeOp(&*it))
+        it++;
+      sr.end = it;
+      assert(sr.begin != sr.end);
+      regions.push_back(sr);
       return true;
+    };
+    while (getOneRegion())
+      ;
+
+    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
+      bool isLast = i + 1 == regions.size();
+      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        omp::SingleOperands singleOperands;
+        if (isLast)
+          singleOperands.nowait = rootBuilder.getUnitAttr();
+        omp::SingleOp singleOp =
+            rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+        OpBuilder singleBuilder(singleOp);
+        singleBuilder.createBlock(&singleOp.getRegion());
+        moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
+      } else {
+        auto op = std::get<Operation *>(opOrSingle);
+        if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
+          omp::WsloopOperands wsloopOperands;
+          if (isLast)
+            wsloopOperands.nowait = rootBuilder.getUnitAttr();
+          auto wsloop =
+              rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+          auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+              rootBuilder.clone(*wslw, rootMapping));
+          wsloop.getRegion().takeBody(clonedWslw.getRegion());
+          clonedWslw->erase();
+        } else {
+          assert(mustParallelizeOp(op));
+          Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
+          for (auto [region, clonedRegion] :
+               llvm::zip(op->getRegions(), cloned->getRegions()))
+            parallelizeRegion(region, clonedRegion, rootMapping, loc);
+        }
+      }
     }
-    SingleRegion sr;
-    sr.begin = it;
-    while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
-      it++;
-    sr.end = it;
-    assert(sr.begin != sr.end);
-    regions.push_back(sr);
-    return true;
-  };
-  while (getSingleRegion())
-    ;
-
-  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
-    bool isLast = i + 1 == regions.size();
-    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
-      omp::SingleOperands singleOperands;
-      if (isLast)
-        singleOperands.nowait = rootBuilder.getUnitAttr();
-      omp::SingleOp singleOp =
-          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
-      OpBuilder singleBuilder(singleOp);
-      singleBuilder.createBlock(&singleOp.getRegion());
-      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
-    } else {
-      omp::WsloopOperands wsloopOperands;
-      if (isLast)
-        wsloopOperands.nowait = rootBuilder.getUnitAttr();
-      auto wsloop =
-          rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
-      auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
-      auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
-          rootBuilder.clone(*wslw, rootMapping));
-      wsloop.getRegion().takeBody(clonedWslw.getRegion());
-      clonedWslw->erase();
-    }
+
+    rootBuilder.clone(*block.getTerminator(), rootMapping);
   }
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.workshare_loop_wrapper {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.workshare_loop_wrapper {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which need
+/// to be accessed in all threads in the closest omp.parallel
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  Location loc = wsOp->getLoc();
+  IRMapping rootMapping;
+
+  OpBuilder rootBuilder(wsOp);
+
+  // TODO We need something like an scf;execute here, but that is not registered
+  // so using fir.if for now but it looks like it does not support multiple
+  // blocks so it doesnt work for multi block case...
+  auto ifOp = rootBuilder.create<fir::IfOp>(
+      loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
+  ifOp.getThenRegion().front().erase();
+
+  parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
+
+  Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
+  assert(isa<omp::TerminatorOp>(terminatorOp));
+  OpBuilder termBuilder(terminatorOp);
 
   if (!wsOp.getNowait())
-    rootBuilder.create<omp::BarrierOp>(loc);
+    termBuilder.create<omp::BarrierOp>(loc);
+
+  termBuilder.create<fir::ResultOp>(loc, ValueRange());
 
+  terminatorOp->erase();
   wsOp->erase();
 
   return;

>From 3604b3dd80e1e1b8347c66cc1d24e2ed72c998d9 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 00:33:57 +0900
Subject: [PATCH 06/19] Leave comments for shouldUseWorkshareLowering

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 8e79d1401c01c6..40dae0fd848ef8 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -40,10 +40,23 @@ using namespace mlir;
 
 namespace flangomp {
 bool shouldUseWorkshareLowering(Operation *op) {
-  auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
-  if (!workshare)
-    return false;
-  return workshare->getParentOfType<omp::ParallelOp>();
+  // TODO this is insufficient, as we could have
+  // omp.parallel {
+  //   omp.workshare {
+  //     omp.parallel {
+  //       hlfir.elemental {}
+  //
+  // Then this hlfir.elemental shall _not_ use the lowering for workshare
+  //
+  // Standard says:
+  //   For a parallel construct, the construct is a unit of work with respect to
+  //   the workshare construct. The statements contained in the parallel
+  //   construct are executed by a new thread team.
+  //
+  // TODO similarly for single, critical, etc. Need to think through the
+  // patterns and implement this function.
+  //
+  return op->getParentOfType<omp::WorkshareOp>();
 }
 } // namespace flangomp
 

>From ea26f32c964cf59ced7d735785c40f4050269e1f Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 13:14:38 +0900
Subject: [PATCH 07/19] Use copyprivate to scatter val from omp.single

TODO still need to implement copy function
TODO transitive check for usage outside of omp.single not implemented yet
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |   3 +-
 flang/include/flang/Tools/CLOptions.inc       |   2 +-
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 138 ++++++++++++++----
 3 files changed, 109 insertions(+), 34 deletions(-)

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 1c9d75d8cfaa18..041240cad12eb3 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -37,7 +37,8 @@ def FunctionFiltering : Pass<"omp-function-filtering"> {
   ];
 }
 
-def LowerWorkshare : Pass<"lower-workshare"> {
+// Needs to be scheduled on Module as we create functions in it
+def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index b256dba3c92949..a565effebfa92e 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
+  pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 40dae0fd848ef8..950737fccada79 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -8,25 +8,27 @@
 // Lower omp workshare construct.
 //===----------------------------------------------------------------------===//
 
-#include "flang/Optimizer/Dialect/FIROps.h"
-#include "flang/Optimizer/Dialect/FIRType.h"
-#include "flang/Optimizer/OpenMP/Passes.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVectorExtras.h"
-#include "llvm/ADT/iterator_range.h"
-
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
-#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
 #include <mlir/IR/Visitors.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
+
 #include <variant>
 
 namespace flangomp {
@@ -71,6 +73,8 @@ static bool isSupportedByFirAlloca(Type ty) {
 }
 
 static bool mustParallelizeOp(Operation *op) {
+  // TODO as in shouldUseWorkshareLowering we be careful not to pick up
+  // workshare_loop_wrapper in nested omp.parallel ops
   return op
       ->walk(
           [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
@@ -78,7 +82,33 @@ static bool mustParallelizeOp(Operation *op) {
 }
 
 static bool isSafeToParallelize(Operation *op) {
-  return isa<fir::DeclareOp>(op) || isPure(op);
+  return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
+         isMemoryEffectFree(op);
+}
+
+static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
+                                         fir::FirOpBuilder builder) {
+  mlir::ModuleOp module = builder.getModule();
+  mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
+
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+
+  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
+    return decl;
+  // create function
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
+  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
+  mlir::func::FuncOp funcOp =
+      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+                      {loc, loc});
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+  builder.create<mlir::func::ReturnOp>(loc);
+  return funcOp;
 }
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
@@ -86,19 +116,23 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
   Operation *parentOp = sourceRegion.getParentOp();
   OpBuilder rootBuilder(sourceRegion.getContext());
 
+  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
+  OpBuilder copyFuncBuilder(m.getBodyRegion());
+  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
+
   // TODO need to copyprivate the alloca's
-  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
-                              IRMapping singleMapping) {
-    OpBuilder allocaBuilder(&targetRegion.front().front());
+  auto mapReloadedValue =
+      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
+          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
     if (auto reloaded = rootMapping.lookupOrNull(v))
-      return;
+      return nullptr;
     Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
     Value alloc, reloaded;
     if (isSupportedByFirAlloca(ty)) {
       alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
       singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
-      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+      reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
     } else {
       auto one = allocaBuilder.create<LLVM::ConstantOp>(
           loc, allocaBuilder.getI32Type(), 1);
@@ -109,21 +143,25 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               loc, llvmPtrTy, singleMapping.lookup(v))
                           .getResult(0);
       singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
-      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
       reloaded =
-          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+          parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
               .getResult(0);
     }
     rootMapping.map(v, reloaded);
+    return alloc;
   };
 
-  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
+                          OpBuilder singleBuilder,
+                          OpBuilder parallelBuilder) -> SmallVector<Value> {
     IRMapping singleMapping = rootMapping;
+    SmallVector<Value> copyPrivate;
 
     for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
       singleBuilder.clone(op, singleMapping);
       if (isSafeToParallelize(&op)) {
-        rootBuilder.clone(op, rootMapping);
+        parallelBuilder.clone(op, rootMapping);
       } else {
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
@@ -132,16 +170,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
             Operation *user = use.getOwner();
             while (user->getParentOp() != parentOp)
               user = user->getParentOp();
-            if (!(user->isBeforeInBlock(&*sr.end) &&
-                  sr.begin->isBeforeInBlock(user))) {
-              // We need to reload
-              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            // TODO we need to look at transitively used vals
+            if (true || !(user->isBeforeInBlock(&*sr.end) &&
+                          sr.begin->isBeforeInBlock(user))) {
+              auto alloc =
+                  mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
+                                   parallelBuilder, singleMapping);
+              if (alloc)
+                copyPrivate.push_back(alloc);
             }
           }
         }
       }
     }
     singleBuilder.create<omp::TerminatorOp>(loc);
+    return copyPrivate;
   };
 
   // TODO Need to handle these (clone them) in dominator tree order
@@ -178,14 +221,45 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
       bool isLast = i + 1 == regions.size();
       if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        OpBuilder singleBuilder(sourceRegion.getContext());
+        Block *singleBlock = new Block();
+        singleBuilder.setInsertionPointToStart(singleBlock);
+
+        OpBuilder allocaBuilder(sourceRegion.getContext());
+        Block *allocaBlock = new Block();
+        allocaBuilder.setInsertionPointToStart(allocaBlock);
+
+        OpBuilder parallelBuilder(sourceRegion.getContext());
+        Block *parallelBlock = new Block();
+        parallelBuilder.setInsertionPointToStart(parallelBlock);
+
         omp::SingleOperands singleOperands;
         if (isLast)
           singleOperands.nowait = rootBuilder.getUnitAttr();
+        auto insPtAtSingle = rootBuilder.saveInsertionPoint();
+        singleOperands.copyprivateVars =
+            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
+                         singleBuilder, parallelBuilder);
+        for (auto var : singleOperands.copyprivateVars) {
+          Type ty;
+          if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
+            ty = firAlloca.getAllocatedType();
+          } else {
+            llvm_unreachable("unexpected");
+          }
+          mlir::func::FuncOp funcOp =
+              createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
+          singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
+        }
         omp::SingleOp singleOp =
             rootBuilder.create<omp::SingleOp>(loc, singleOperands);
-        OpBuilder singleBuilder(singleOp);
-        singleBuilder.createBlock(&singleOp.getRegion());
-        moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
+        singleOp.getRegion().push_back(singleBlock);
+        rootBuilder.getInsertionBlock()->getOperations().splice(
+            rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
+        targetRegion.front().getOperations().splice(
+            singleOp->getIterator(), allocaBlock->getOperations());
+        delete allocaBlock;
+        delete parallelBlock;
       } else {
         auto op = std::get<Operation *>(opOrSingle);
         if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {

>From add3b79ee577266b38b4b01a9d3dfe7d00ee6fa6 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 13:39:49 +0900
Subject: [PATCH 08/19] Transitively check for users outside of single op

TODO need to implement copy func
TODO need to hoist allocas outside of single regions
---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 51 ++++++++++++++-----
 1 file changed, 37 insertions(+), 14 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 950737fccada79..2e88d852ff2cba 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -111,6 +111,38 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
   return funcOp;
 }
 
+static bool isUserOutsideSR(Operation *user, Operation *parentOp,
+                            SingleRegion sr) {
+  while (user->getParentOp() != parentOp)
+    user = user->getParentOp();
+  return sr.begin->getBlock() != user->getBlock() ||
+         !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
+}
+
+static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
+  Block *srBlock = sr.begin->getBlock();
+  Operation *parentOp = srBlock->getParentOp();
+
+  for (auto &use : v.getUses()) {
+    Operation *user = use.getOwner();
+    if (isUserOutsideSR(user, parentOp, sr))
+      return true;
+
+    // Results of nested users cannot be used outside of the SR
+    if (user->getBlock() != srBlock)
+      continue;
+
+    // A non-safe to parallelize operation will be handled separately
+    if (!isSafeToParallelize(user))
+      continue;
+
+    for (auto res : user->getResults())
+      if (isTransitivelyUsedOutside(res, sr))
+        return true;
+  }
+  return false;
+}
+
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
   Operation *parentOp = sourceRegion.getParentOp();
@@ -166,19 +198,11 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
         for (auto res : op.getResults()) {
-          for (auto &use : res.getUses()) {
-            Operation *user = use.getOwner();
-            while (user->getParentOp() != parentOp)
-              user = user->getParentOp();
-            // TODO we need to look at transitively used vals
-            if (true || !(user->isBeforeInBlock(&*sr.end) &&
-                          sr.begin->isBeforeInBlock(user))) {
-              auto alloc =
-                  mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
-                                   parallelBuilder, singleMapping);
-              if (alloc)
-                copyPrivate.push_back(alloc);
-            }
+          if (isTransitivelyUsedOutside(res, sr)) {
+            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
+                                          parallelBuilder, singleMapping);
+            if (alloc)
+              copyPrivate.push_back(alloc);
           }
         }
       }
@@ -236,7 +260,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         omp::SingleOperands singleOperands;
         if (isLast)
           singleOperands.nowait = rootBuilder.getUnitAttr();
-        auto insPtAtSingle = rootBuilder.saveInsertionPoint();
         singleOperands.copyprivateVars =
             moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                          singleBuilder, parallelBuilder);

>From 92c1e7787f772333fbca37f1b2c83b3f1fb6b257 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 14:30:43 +0900
Subject: [PATCH 09/19] Add tests

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  25 +-
 .../Transforms/OpenMP/lower-workshare.mlir    | 230 +++++++++++++-----
 2 files changed, 188 insertions(+), 67 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 2e88d852ff2cba..30af2556cf4cae 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -18,6 +18,7 @@
 #include <llvm/ADT/iterator_range.h>
 #include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
 #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
 #include <mlir/IR/BuiltinOps.h>
@@ -75,6 +76,14 @@ static bool isSupportedByFirAlloca(Type ty) {
 static bool mustParallelizeOp(Operation *op) {
   // TODO as in shouldUseWorkshareLowering we be careful not to pick up
   // workshare_loop_wrapper in nested omp.parallel ops
+  //
+  // e.g.
+  //
+  // omp.parallel {
+  //   omp.workshare {
+  //     omp.parallel {
+  //       omp.workshare {
+  //         omp.workshare_loop_wrapper {}
   return op
       ->walk(
           [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
@@ -89,10 +98,14 @@ static bool isSafeToParallelize(Operation *op) {
 static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                          fir::FirOpBuilder builder) {
   mlir::ModuleOp module = builder.getModule();
-  mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
-
-  std::string copyFuncName =
-      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  std::string copyFuncName;
+  if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
+    mlir::Type eleTy = rt.getEleTy();
+    copyFuncName =
+        fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+  } else {
+    copyFuncName = "_workshare_copy_llvm_ptr";
+  }
 
   if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
     return decl;
@@ -145,9 +158,7 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
-  Operation *parentOp = sourceRegion.getParentOp();
   OpBuilder rootBuilder(sourceRegion.getContext());
-
   ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
   OpBuilder copyFuncBuilder(m.getBodyRegion());
   fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
@@ -268,7 +279,7 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
             ty = firAlloca.getAllocatedType();
           } else {
-            llvm_unreachable("unexpected");
+            ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
           }
           mlir::func::FuncOp funcOp =
               createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index cb5791d35916a9..19123e71cacf60 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -1,80 +1,190 @@
-// RUN: fir-opt --lower-workshare %s | FileCheck %s
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
-module {
-// CHECK-LABEL:   func.func @simple(
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+    omp.workshare {
+      %c42 = arith.constant 42 : index
+      %c1_i32 = arith.constant 1 : i32
+      %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+      %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+      %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+      %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+      %true = arith.constant true
+      %c1 = arith.constant 1 : index
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+          %8 = fir.load %7 : !fir.ref<i32>
+          %9 = arith.subi %8, %c1_i32 : i32
+          %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+          hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+      %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+      hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+      fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+
+// -----
+
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.workshare {
+    %c1_i32 = arith.constant 1 : i32
+    %alloc = fir.alloca i32
+    fir.store %c1_i32 to %alloc : !fir.ref<i32>
+    %c42 = arith.constant 42 : index
+    %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+    %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+    %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+    %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+    %true = arith.constant true
+    %c1 = arith.constant 1 : index
+    "omp.workshare_loop_wrapper"() ({
+      omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+        %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        %8 = fir.load %7 : !fir.ref<i32>
+        %ld = fir.load %alloc : !fir.ref<i32>
+        %n8 = arith.subi %8, %ld : i32
+        %9 = arith.subi %n8, %c1_i32 : i32
+        %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+        hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
+    }) : () -> ()
+    %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+    "test.test1"(%alloc) : (!fir.ref<i32>) -> ()
+    hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+    fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+    omp.terminator
+  }
+  return
+}
+
+
+// CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK:           %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
-// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           %[[VAL_4:.*]] = arith.constant true
 // CHECK:           omp.parallel {
-// CHECK:             omp.single {
-// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
-// CHECK:               llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
-// CHECK:               %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             fir.if %[[VAL_4]] {
+// CHECK:               %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:                 %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:                 %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:                 %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:                 fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:                 %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               omp.wsloop {
+// CHECK:                 omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                   %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK:                   %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
+// CHECK:                   %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                   hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.single nowait {
+// CHECK:                 hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:                 fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.barrier
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_workshare_copy_llvm_ptr(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME:                                                %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @wsfunc(
+// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_5:.*]] = arith.constant true
+// CHECK:           fir.if %[[VAL_5]] {
+// CHECK:             %[[VAL_6:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK:             %[[VAL_7:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_6]] -> @_workshare_copy_llvm_ptr : !llvm.ptr, %[[VAL_7]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               %[[VAL_8:.*]] = fir.alloca i32
+// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !fir.ref<i32> to !llvm.ptr
+// CHECK:               llvm.store %[[VAL_9]], %[[VAL_6]] : !llvm.ptr, !llvm.ptr
+// CHECK:               fir.store %[[VAL_3]] to %[[VAL_8]] : !fir.ref<i32>
+// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_12:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_12]] to %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
-// CHECK:             %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
-// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_14:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> !llvm.ptr
+// CHECK:             %[[VAL_15:.*]] = builtin.unrealized_conversion_cast %[[VAL_14]] : !llvm.ptr to !fir.ref<i32>
+// CHECK:             %[[VAL_16:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_16]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_18]](%[[VAL_16]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
-// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.loop_nest (%[[VAL_20:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_17]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_24:.*]] = arith.subi %[[VAL_22]], %[[VAL_23]] : i32
+// CHECK:                 %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_3]] : i32
+// CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_19]]#0 (%[[VAL_20]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_25]] to %[[VAL_26]] temporary_lhs : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.single nowait {
-// CHECK:               hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK:               "test.test1"(%[[VAL_15]]) : (!fir.ref<i32>) -> ()
+// CHECK:               hlfir.assign %[[VAL_19]]#0 to %[[VAL_17]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_19]]#0 : !fir.heap<!fir.array<42xi32>>
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.barrier
-// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
-  func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
-    omp.parallel {
-      omp.workshare {
-        %c42 = arith.constant 42 : index
-        %c1_i32 = arith.constant 1 : i32
-        %0 = fir.shape %c42 : (index) -> !fir.shape<1>
-        %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-        %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-        %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-        %true = arith.constant true
-        %c1 = arith.constant 1 : index
-        "omp.workshare_loop_wrapper"() ({
-          omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
-            %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-            %8 = fir.load %7 : !fir.ref<i32>
-            %9 = arith.subi %8, %c1_i32 : i32
-            %10 = hlfir.designate %3#0 (%arg1)  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-            hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
-            omp.yield
-          }
-          omp.terminator
-        }) : () -> ()
-        %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
-        hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-        fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
-        omp.terminator
-      }
-      omp.terminator
-    }
-    return
-  }
-}
+

>From 6fe32271a39c9381c7de68704e9d939388c5dfff Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 15:50:58 +0900
Subject: [PATCH 10/19] Hoist allocas

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 10 ++-
 .../Transforms/OpenMP/lower-workshare.mlir    | 69 +++++++++----------
 2 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 30af2556cf4cae..d0cd235d3eb079 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -163,7 +163,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
   OpBuilder copyFuncBuilder(m.getBodyRegion());
   fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
 
-  // TODO need to copyprivate the alloca's
   auto mapReloadedValue =
       [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
           OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
@@ -202,10 +201,17 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     SmallVector<Value> copyPrivate;
 
     for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
-      singleBuilder.clone(op, singleMapping);
       if (isSafeToParallelize(&op)) {
+        singleBuilder.clone(op, singleMapping);
         parallelBuilder.clone(op, rootMapping);
+      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
+        auto hoisted =
+            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
+        rootMapping.map(&*alloca, &*hoisted);
+        rootMapping.map(alloca.getResult(), hoisted.getResult());
+        copyPrivate.push_back(hoisted);
       } else {
+        singleBuilder.clone(op, singleMapping);
         // Prepare reloaded values for results of operations that cannot be
         // safely parallelized and which are used after the region `sr`
         for (auto res : op.getResults()) {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 19123e71cacf60..b78cfd80e17acb 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -1,5 +1,7 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
+// Checks that the final omp.single of the lowered workshare region is
+// emitted with `nowait` and is followed by an explicit omp.barrier.
 func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   omp.parallel {
     omp.workshare {
@@ -37,6 +39,8 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // -----
 
+// Checks that the fir.alloca inside the workshare region is hoisted out of
+// the omp.single and passed to it as a copyprivate variable.
 func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   omp.workshare {
     %c1_i32 = arith.constant 1 : i32
@@ -73,7 +77,6 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   return
 }
 
-
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
@@ -130,9 +133,9 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           return
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_workshare_copy_llvm_ptr(
-// CHECK-SAME:                                                %[[VAL_0:.*]]: !llvm.ptr,
-// CHECK-SAME:                                                %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK-LABEL:   func.func private @_workshare_copy_i32(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: !fir.ref<i32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: !fir.ref<i32>) {
 // CHECK:           return
 // CHECK:         }
 
@@ -141,46 +144,40 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK:           %[[VAL_5:.*]] = arith.constant true
-// CHECK:           fir.if %[[VAL_5]] {
-// CHECK:             %[[VAL_6:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
-// CHECK:             %[[VAL_7:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:             omp.single copyprivate(%[[VAL_6]] -> @_workshare_copy_llvm_ptr : !llvm.ptr, %[[VAL_7]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:               %[[VAL_8:.*]] = fir.alloca i32
-// CHECK:               %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !fir.ref<i32> to !llvm.ptr
-// CHECK:               llvm.store %[[VAL_9]], %[[VAL_6]] : !llvm.ptr, !llvm.ptr
-// CHECK:               fir.store %[[VAL_3]] to %[[VAL_8]] : !fir.ref<i32>
-// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_12:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               fir.store %[[VAL_12]] to %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_4:.*]] = arith.constant true
+// CHECK:           fir.if %[[VAL_4]] {
+// CHECK:             %[[VAL_5:.*]] = fir.alloca i32
+// CHECK:             %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
+// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_14:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> !llvm.ptr
-// CHECK:             %[[VAL_15:.*]] = builtin.unrealized_conversion_cast %[[VAL_14]] : !llvm.ptr to !fir.ref<i32>
-// CHECK:             %[[VAL_16:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:             %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_16]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:             %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_18]](%[[VAL_16]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
 // CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_20:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_17]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_24:.*]] = arith.subi %[[VAL_22]], %[[VAL_23]] : i32
-// CHECK:                 %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_3]] : i32
-// CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_19]]#0 (%[[VAL_20]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_25]] to %[[VAL_26]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
+// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
+// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.single nowait {
-// CHECK:               "test.test1"(%[[VAL_15]]) : (!fir.ref<i32>) -> ()
-// CHECK:               hlfir.assign %[[VAL_19]]#0 to %[[VAL_17]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_19]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
+// CHECK:               hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
 // CHECK:               omp.terminator
 // CHECK:             }
 // CHECK:             omp.barrier

>From a9d8578f5bc00c199338666866126dbfe86d3089 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 15:52:23 +0900
Subject: [PATCH 11/19] More tests

---
 .../Transforms/OpenMP/lower-workshare2.mlir   | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare2.mlir

diff --git a/flang/test/Transforms/OpenMP/lower-workshare2.mlir b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
new file mode 100644
index 00000000000000..325a40d4184453
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -0,0 +1,21 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @nonowait
+func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  // CHECK: omp.barrier
+  omp.workshare {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @nowait
+func.func @nowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  // CHECK-NOT: omp.barrier
+  omp.workshare nowait {
+    omp.terminator
+  }
+  return
+}

>From 8395991ea4e1e034c9d107641c6b4119008aa403 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:13:39 +0900
Subject: [PATCH 12/19] Emit body for copy func

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 44 +++++--------------
 .../Transforms/OpenMP/lower-workshare.mlir    |  6 +++
 2 files changed, 17 insertions(+), 33 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index d0cd235d3eb079..20f45296a8159a 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -69,10 +69,6 @@ struct SingleRegion {
   Block::iterator begin, end;
 };
 
-static bool isSupportedByFirAlloca(Type ty) {
-  return !isa<fir::ReferenceType>(ty);
-}
-
 static bool mustParallelizeOp(Operation *op) {
   // TODO as in shouldUseWorkshareLowering we be careful not to pick up
   // workshare_loop_wrapper in nested omp.parallel ops
@@ -98,14 +94,10 @@ static bool isSafeToParallelize(Operation *op) {
 static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                          fir::FirOpBuilder builder) {
   mlir::ModuleOp module = builder.getModule();
-  std::string copyFuncName;
-  if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
-    mlir::Type eleTy = rt.getEleTy();
-    copyFuncName =
-        fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
-  } else {
-    copyFuncName = "_workshare_copy_llvm_ptr";
-  }
+  auto rt = cast<fir::ReferenceType>(varType);
+  mlir::Type eleTy = rt.getEleTy();
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
 
   if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
     return decl;
@@ -120,6 +112,10 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
   builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
                       {loc, loc});
   builder.setInsertionPointToStart(&funcOp.getRegion().back());
+
+  Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(0));
+  builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(1));
+
   builder.create<mlir::func::ReturnOp>(loc);
   return funcOp;
 }
@@ -168,28 +164,10 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
     if (auto reloaded = rootMapping.lookupOrNull(v))
       return nullptr;
-    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
     Type ty = v.getType();
-    Value alloc, reloaded;
-    if (isSupportedByFirAlloca(ty)) {
-      alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
-      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
-      reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
-    } else {
-      auto one = allocaBuilder.create<LLVM::ConstantOp>(
-          loc, allocaBuilder.getI32Type(), 1);
-      alloc =
-          allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
-      Value toStore = singleBuilder
-                          .create<UnrealizedConversionCastOp>(
-                              loc, llvmPtrTy, singleMapping.lookup(v))
-                          .getResult(0);
-      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
-      reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
-      reloaded =
-          parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
-              .getResult(0);
-    }
+    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
+    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
     rootMapping.map(v, reloaded);
     return alloc;
   };
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index b78cfd80e17acb..997bc8d79f9b3f 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -80,6 +80,8 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
 // CHECK:           return
 // CHECK:         }
 
@@ -130,12 +132,16 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK-LABEL:   func.func private @_workshare_copy_heap_42xi32(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
 // CHECK-SAME:                                                   %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
 // CHECK:           return
 // CHECK:         }
 
 // CHECK-LABEL:   func.func private @_workshare_copy_i32(
 // CHECK-SAME:                                           %[[VAL_0:.*]]: !fir.ref<i32>,
 // CHECK-SAME:                                           %[[VAL_1:.*]]: !fir.ref<i32>) {
+// CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<i32>
 // CHECK:           return
 // CHECK:         }
 

>From 8bc89ab0ed9dd3607fe4bfb325418122febede09 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 16:59:34 +0900
Subject: [PATCH 13/19] Test the tmp storing logic

---
 .../Transforms/OpenMP/lower-workshare.mlir    |  2 -
 .../Transforms/OpenMP/lower-workshare3.mlir   | 74 +++++++++++++++++++
 2 files changed, 74 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare3.mlir

diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 997bc8d79f9b3f..063d3865065e01 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -36,7 +36,6 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
   return
 }
 
-
 // -----
 
 // checks:
@@ -190,4 +189,3 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
-
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
new file mode 100644
index 00000000000000..84eded94503282
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -0,0 +1,74 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+
+// tests if the correct values are stored
+
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+  omp.parallel {
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK: fir.alloca
+  // CHECK-NOT: fir.alloca
+    omp.workshare {
+
+      %t1 = "test.test1"() : () -> i32
+      // CHECK: %[[T1:.*]] = "test.test1"
+      // CHECK: fir.store %[[T1]]
+      %t2 = "test.test2"() : () -> i32
+      // CHECK: %[[T2:.*]] = "test.test2"
+      // CHECK: fir.store %[[T2]]
+      %t3 = "test.test3"() : () -> i32
+      // CHECK: %[[T3:.*]] = "test.test3"
+      // CHECK-NOT: fir.store %[[T3]]
+      %t4 = "test.test4"() : () -> i32
+      // CHECK: %[[T4:.*]] = "test.test4"
+      // CHECK: fir.store %[[T4]]
+      %t5 = "test.test5"() : () -> i32
+      // CHECK: %[[T5:.*]] = "test.test5"
+      // CHECK: fir.store %[[T5]]
+      %t6 = "test.test6"() : () -> i32
+      // CHECK: %[[T6:.*]] = "test.test6"
+      // CHECK-NOT: fir.store %[[T6]]
+
+
+      "test.test1"(%t1) : (i32) -> ()
+      "test.test1"(%t2) : (i32) -> ()
+      "test.test1"(%t3) : (i32) -> ()
+
+      %true = arith.constant true
+      fir.if %true {
+        "test.test2"(%t3) : (i32) -> ()
+      }
+
+      %c1_i32 = arith.constant 1 : i32
+
+      %t5_pure_use = arith.addi %t5, %c1_i32 : i32
+
+      %t6_mem_effect_use = "test.test8"(%t6) : (i32) -> i32
+      // CHECK: %[[T6_USE:.*]] = "test.test8"
+      // CHECK: fir.store %[[T6_USE]]
+
+      %c42 = arith.constant 42 : index
+      %c1 = arith.constant 1 : index
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          "test.test10"(%t1) : (i32) -> ()
+          "test.test10"(%t5_pure_use) : (i32) -> ()
+          "test.test10"(%t6_mem_effect_use) : (i32) -> ()
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+
+      "test.test10"(%t2) : (i32) -> ()
+      fir.if %true {
+        "test.test10"(%t4) : (i32) -> ()
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}

>From a2fc6fa60690a38ed534443db0c0c8b8b8d4799a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 20:24:36 +0900
Subject: [PATCH 14/19] Clean up trivially dead ops

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 32 ++++-------
 .../Transforms/OpenMP/lower-workshare3.mlir   |  2 +-
 .../Transforms/OpenMP/lower-workshare4.mlir   | 55 +++++++++++++++++++
 3 files changed, 68 insertions(+), 21 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare4.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 20f45296a8159a..a147db2cb5d59a 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -152,6 +152,14 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
   return false;
 }
 
+/// We clone pure operations in both the parallel and single blocks. This
+/// function cleans them up if they end up with no uses.
+static void cleanupBlock(Block *block) {
+  for (Operation &op : llvm::make_early_inc_range(*block))
+    if (isOpTriviallyDead(&op))
+      op.erase();
+}
+
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                               IRMapping &rootMapping, Location loc) {
   OpBuilder rootBuilder(sourceRegion.getContext());
@@ -258,13 +266,8 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
         singleOperands.copyprivateVars =
             moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                          singleBuilder, parallelBuilder);
+        cleanupBlock(singleBlock);
         for (auto var : singleOperands.copyprivateVars) {
-          Type ty;
-          if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
-            ty = firAlloca.getAllocatedType();
-          } else {
-            ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
-          }
           mlir::func::FuncOp funcOp =
               createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
           singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
@@ -302,6 +305,9 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 
     rootBuilder.clone(*block.getTerminator(), rootMapping);
   }
+
+  for (Block &targetBlock : targetRegion)
+    cleanupBlock(&targetBlock);
 }
 
 /// Lowers workshare to a sequence of single-thread regions and parallel loops
@@ -372,20 +378,6 @@ class LowerWorksharePass
 
       lowerWorkshare(wsOp);
     });
-
-    // Do folding
-    for (Operation *isolatedParent : parents) {
-      RewritePatternSet patterns(&getContext());
-      GreedyRewriteConfig config;
-      // prevent the pattern driver form merging blocks
-      config.enableRegionSimplification =
-          mlir::GreedySimplifyRegionLevel::Disabled;
-      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
-                                              std::move(patterns), config))) {
-        emitError(isolatedParent->getLoc(), "error in lower workshare\n");
-        signalPassFailure();
-      }
-    }
   }
 };
 } // namespace
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index 84eded94503282..aee95a464a31bd 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -3,7 +3,7 @@
 
 // tests if the correct values are stored
 
-func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+func.func @wsfunc() {
   omp.parallel {
   // CHECK: fir.alloca
   // CHECK: fir.alloca
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
new file mode 100644
index 00000000000000..6cff0075b4fe50
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -0,0 +1,55 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+func.func @wsfunc() {
+  %a = fir.alloca i32
+  omp.parallel {
+    omp.workshare {
+      %t1 = "test.test1"() : () -> i32
+
+      %c1 = arith.constant 1 : index
+      %c42 = arith.constant 42 : index
+
+      %c2 = arith.constant 2 : index
+      "test.test3"(%c2) : (index) -> ()
+
+      "omp.workshare_loop_wrapper"() ({
+        omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+          "test.test2"() : () -> ()
+          omp.yield
+        }
+        omp.terminator
+      }) : () -> ()
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @wsfunc() {
+// CHECK:           %[[VAL_0:.*]] = fir.alloca i32
+// CHECK:           omp.parallel {
+// CHECK:             %[[VAL_1:.*]] = arith.constant true
+// CHECK:             fir.if %[[VAL_1]] {
+// CHECK:               omp.single {
+// CHECK:                 %[[VAL_2:.*]] = "test.test1"() : () -> i32
+// CHECK:                 %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:                 "test.test3"(%[[VAL_3]]) : (index) -> ()
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_5:.*]] = arith.constant 42 : index
+// CHECK:               omp.wsloop nowait {
+// CHECK:                 omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) {
+// CHECK:                   "test.test2"() : () -> ()
+// CHECK:                   omp.yield
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.barrier
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+

>From 1658a875460572bfe449840966b518e423594ac3 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 21:55:14 +0900
Subject: [PATCH 15/19] Only handle single-block regions for now

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  80 +++++----
 .../Transforms/OpenMP/lower-workshare.mlir    | 154 +++++++++---------
 .../Transforms/OpenMP/lower-workshare4.mlir   |  31 ++--
 3 files changed, 143 insertions(+), 122 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index a147db2cb5d59a..5998489c13d382 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -13,12 +13,14 @@
 #include <flang/Optimizer/Dialect/FIRType.h>
 #include <flang/Optimizer/HLFIR/HLFIROps.h>
 #include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
 #include <llvm/ADT/STLExtras.h>
 #include <llvm/ADT/SmallVectorExtras.h>
 #include <llvm/ADT/iterator_range.h>
 #include <llvm/Support/ErrorHandling.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
 #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
 #include <mlir/IR/BuiltinOps.h>
@@ -161,7 +163,8 @@ static void cleanupBlock(Block *block) {
 }
 
 static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
-                              IRMapping &rootMapping, Location loc) {
+                              IRMapping &rootMapping, Location loc,
+                              mlir::DominanceInfo &di) {
   OpBuilder rootBuilder(sourceRegion.getContext());
   ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
   OpBuilder copyFuncBuilder(m.getBodyRegion());
@@ -214,14 +217,19 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     return copyPrivate;
   };
 
-  // TODO Need to handle these (clone them) in dominator tree order
   for (Block &block : sourceRegion) {
-    rootBuilder.createBlock(
+    Block *targetBlock = rootBuilder.createBlock(
         &targetRegion, {}, block.getArgumentTypes(),
         llvm::map_to_vector(block.getArguments(),
                             [](BlockArgument arg) { return arg.getLoc(); }));
-    Operation *terminator = block.getTerminator();
+    rootMapping.map(&block, targetBlock);
+    rootMapping.map(block.getArguments(), targetBlock->getArguments());
+  }
 
+  auto handleOneBlock = [&](Block &block) {
+    Block &targetBlock = *rootMapping.lookup(&block);
+    rootBuilder.setInsertionPointToStart(&targetBlock);
+    Operation *terminator = block.getTerminator();
     SmallVector<std::variant<SingleRegion, Operation *>> regions;
 
     auto it = block.begin();
@@ -298,12 +306,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
           Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
           for (auto [region, clonedRegion] :
                llvm::zip(op->getRegions(), cloned->getRegions()))
-            parallelizeRegion(region, clonedRegion, rootMapping, loc);
+            parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
         }
       }
     }
 
     rootBuilder.clone(*block.getTerminator(), rootMapping);
+  };
+
+  if (sourceRegion.hasOneBlock()) {
+    handleOneBlock(sourceRegion.front());
+  } else {
+    auto &domTree = di.getDomTree(&sourceRegion);
+    for (auto node : llvm::breadth_first(domTree.getRootNode())) {
+      handleOneBlock(*node->getBlock());
+    }
   }
 
   for (Block &targetBlock : targetRegion)
@@ -336,47 +353,46 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 ///
 /// Note that we allocate temporary memory for values in omp.single's which need
 /// to be accessed in all threads in the closest omp.parallel
-void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
   Location loc = wsOp->getLoc();
   IRMapping rootMapping;
 
   OpBuilder rootBuilder(wsOp);
 
-  // TODO We need something like an scf;execute here, but that is not registered
-  // so using fir.if for now but it looks like it does not support multiple
-  // blocks so it doesnt work for multi block case...
-  auto ifOp = rootBuilder.create<fir::IfOp>(
-      loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
-  ifOp.getThenRegion().front().erase();
-
-  parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
-
-  Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
-  assert(isa<omp::TerminatorOp>(terminatorOp));
-  OpBuilder termBuilder(terminatorOp);
-
+  // TODO We need something like an scf.execute_region here, but that is not
+  // registered, so we use omp.workshare as a placeholder. We need this op
+  // because parallelizeRegion works on regions and not blocks.
+  omp::WorkshareOp newOp =
+      rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
   if (!wsOp.getNowait())
-    termBuilder.create<omp::BarrierOp>(loc);
-
-  termBuilder.create<fir::ResultOp>(loc, ValueRange());
-
-  terminatorOp->erase();
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc, di);
+
+  if (wsOp.getRegion().getBlocks().size() != 1)
+    return failure();
+
+  // Inline the contents of the placeholder workshare op into its parent block.
+  Block *theBlock = &newOp.getRegion().front();
+  Operation *term = theBlock->getTerminator();
+  Block *parentBlock = wsOp->getBlock();
+  parentBlock->getOperations().splice(newOp->getIterator(),
+                                      theBlock->getOperations());
+  assert(term->getNumOperands() == 0);
+  term->erase();
+  newOp->erase();
   wsOp->erase();
-
-  return;
+  return success();
 }
 
 class LowerWorksharePass
     : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
 public:
   void runOnOperation() override {
-    SmallPtrSet<Operation *, 8> parents;
+    mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
     getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
-      Operation *isolatedParent =
-          wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
-      parents.insert(isolatedParent);
-
-      lowerWorkshare(wsOp);
+      if (failed(lowerWorkshare(wsOp, di)))
+        signalPassFailure();
     });
   }
 };
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index 063d3865065e01..b31e951223d56f 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -86,43 +86,46 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant true
 // CHECK:           omp.parallel {
-// CHECK:             fir.if %[[VAL_4]] {
-// CHECK:               %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:               omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:                 %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:                 %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:                 %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:                 fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:                 %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               omp.wsloop {
-// CHECK:                 omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                   %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                   %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
-// CHECK:                   %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
-// CHECK:                   %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                   hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
-// CHECK:                   omp.yield
-// CHECK:                 }
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               omp.single nowait {
-// CHECK:                 hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:                 fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
-// CHECK:                 omp.terminator
+// CHECK:             %[[VAL_1:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:               %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK:               %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:               %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:               %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:               fir.store %[[VAL_5]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:               %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_7:.*]] = arith.constant 42 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : i32
+// CHECK:             %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_9]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_11]](%[[VAL_9]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_13:.*]] = arith.constant true
+// CHECK:             %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_14]]) to (%[[VAL_7]]) inclusive step (%[[VAL_14]]) {
+// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:                 %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_8]] : i32
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:                 hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:                 omp.yield
 // CHECK:               }
-// CHECK:               omp.barrier
+// CHECK:               omp.terminator
 // CHECK:             }
+// CHECK:             omp.single nowait {
+// CHECK:               %[[VAL_20:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               %[[VAL_21:.*]] = fir.insert_value %[[VAL_20]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:               hlfir.assign %[[VAL_12]]#0 to %[[VAL_10]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:               fir.freemem %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_22:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             %[[VAL_23:.*]] = fir.insert_value %[[VAL_22]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             omp.barrier
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
@@ -146,46 +149,51 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
 
 // CHECK-LABEL:   func.func @wsfunc(
 // CHECK-SAME:                      %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_4:.*]] = arith.constant true
-// CHECK:           fir.if %[[VAL_4]] {
-// CHECK:             %[[VAL_5:.*]] = fir.alloca i32
-// CHECK:             %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
-// CHECK:             omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
-// CHECK:               fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
-// CHECK:               %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:               %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:               %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
-// CHECK:               fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:               %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:               omp.terminator
-// CHECK:             }
-// CHECK:             %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
-// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
-// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
-// CHECK:             %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
-// CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
-// CHECK:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
-// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
-// CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
-// CHECK:                 hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
-// CHECK:                 omp.yield
-// CHECK:               }
-// CHECK:               omp.terminator
-// CHECK:             }
-// CHECK:             omp.single nowait {
-// CHECK:               "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
-// CHECK:               hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
-// CHECK:               fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
-// CHECK:               omp.terminator
+// CHECK:           %[[VAL_1:.*]] = fir.alloca i32
+// CHECK:           %[[VAL_2:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK:           omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_2]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_1]] : !fir.ref<i32>
+// CHECK:             %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK:             %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:             %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:             %[[VAL_7:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK:             fir.store %[[VAL_7]] to %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:             %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_7]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           %[[VAL_9:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_10:.*]] = arith.constant 42 : index
+// CHECK:           %[[VAL_11:.*]] = fir.shape %[[VAL_10]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK:           %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK:           %[[VAL_15:.*]] = arith.constant true
+// CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:           omp.wsloop {
+// CHECK:             omp.loop_nest (%[[VAL_17:.*]]) : index = (%[[VAL_16]]) to (%[[VAL_10]]) inclusive step (%[[VAL_16]]) {
+// CHECK:               %[[VAL_18:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_17]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:               %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
+// CHECK:               %[[VAL_20:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
+// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_19]], %[[VAL_20]] : i32
+// CHECK:               %[[VAL_22:.*]] = arith.subi %[[VAL_21]], %[[VAL_9]] : i32
+// CHECK:               %[[VAL_23:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_17]])  : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK:               hlfir.assign %[[VAL_22]] to %[[VAL_23]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK:               omp.yield
 // CHECK:             }
-// CHECK:             omp.barrier
+// CHECK:             omp.terminator
 // CHECK:           }
+// CHECK:           omp.single nowait {
+// CHECK:             %[[VAL_24:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             %[[VAL_25:.*]] = fir.insert_value %[[VAL_24]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:             "test.test1"(%[[VAL_1]]) : (!fir.ref<i32>) -> ()
+// CHECK:             hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK:             fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           %[[VAL_26:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:           %[[VAL_27:.*]] = fir.insert_value %[[VAL_26]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK:           omp.barrier
 // CHECK:           return
 // CHECK:         }
+
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index 6cff0075b4fe50..d695a1c354517b 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -29,25 +29,22 @@ func.func @wsfunc() {
 // CHECK-LABEL:   func.func @wsfunc() {
 // CHECK:           %[[VAL_0:.*]] = fir.alloca i32
 // CHECK:           omp.parallel {
-// CHECK:             %[[VAL_1:.*]] = arith.constant true
-// CHECK:             fir.if %[[VAL_1]] {
-// CHECK:               omp.single {
-// CHECK:                 %[[VAL_2:.*]] = "test.test1"() : () -> i32
-// CHECK:                 %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK:                 "test.test3"(%[[VAL_3]]) : (index) -> ()
-// CHECK:                 omp.terminator
-// CHECK:               }
-// CHECK:               %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:               %[[VAL_5:.*]] = arith.constant 42 : index
-// CHECK:               omp.wsloop nowait {
-// CHECK:                 omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) {
-// CHECK:                   "test.test2"() : () -> ()
-// CHECK:                   omp.yield
-// CHECK:                 }
-// CHECK:                 omp.terminator
+// CHECK:             omp.single {
+// CHECK:               %[[VAL_1:.*]] = "test.test1"() : () -> i32
+// CHECK:               %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:               "test.test3"(%[[VAL_2]]) : (index) -> ()
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK:             omp.wsloop nowait {
+// CHECK:               omp.loop_nest (%[[VAL_5:.*]]) : index = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_3]]) {
+// CHECK:                 "test.test2"() : () -> ()
+// CHECK:                 omp.yield
 // CHECK:               }
-// CHECK:               omp.barrier
+// CHECK:               omp.terminator
 // CHECK:             }
+// CHECK:             omp.barrier
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return

>From b5aadf53b32d186e5f31f561edf13c3b4f41005b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 6 Aug 2024 13:52:20 +0900
Subject: [PATCH 16/19] Fix tests for custom assembly for loop wrapper

---
 flang/test/Transforms/OpenMP/lower-workshare.mlir  | 8 ++++----
 flang/test/Transforms/OpenMP/lower-workshare3.mlir | 4 ++--
 flang/test/Transforms/OpenMP/lower-workshare4.mlir | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
index b31e951223d56f..9347863dc4a609 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -13,7 +13,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
       %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
       %true = arith.constant true
       %c1 = arith.constant 1 : index
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
           %8 = fir.load %7 : !fir.ref<i32>
@@ -23,7 +23,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
       %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
       %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
       %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
@@ -52,7 +52,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
     %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
     %true = arith.constant true
     %c1 = arith.constant 1 : index
-    "omp.workshare_loop_wrapper"() ({
+    omp.workshare_loop_wrapper {
       omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
         %7 = hlfir.designate %1#0 (%arg1)  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
         %8 = fir.load %7 : !fir.ref<i32>
@@ -64,7 +64,7 @@ func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
         omp.yield
       }
       omp.terminator
-    }) : () -> ()
+    }
     %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
     %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
     %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index aee95a464a31bd..afb41d95e7198e 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -52,7 +52,7 @@ func.func @wsfunc() {
 
       %c42 = arith.constant 42 : index
       %c1 = arith.constant 1 : index
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           "test.test10"(%t1) : (i32) -> ()
           "test.test10"(%t5_pure_use) : (i32) -> ()
@@ -60,7 +60,7 @@ func.func @wsfunc() {
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
 
       "test.test10"(%t2) : (i32) -> ()
       fir.if %true {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index d695a1c354517b..0a70007a9e78dd 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -12,13 +12,13 @@ func.func @wsfunc() {
       %c2 = arith.constant 2 : index
       "test.test3"(%c2) : (index) -> ()
 
-      "omp.workshare_loop_wrapper"() ({
+      omp.workshare_loop_wrapper {
         omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
           "test.test2"() : () -> ()
           omp.yield
         }
         omp.terminator
-      }) : () -> ()
+      }
       omp.terminator
     }
     omp.terminator

>From 08584697956e12be8a83e93de795e4594aabd1e3 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 14:43:50 +0900
Subject: [PATCH 17/19] Only run the lower workshare pass if openmp is enabled

---
 flang/include/flang/Tools/CLOptions.inc      |  7 ++++---
 flang/include/flang/Tools/CrossToolHelpers.h |  1 +
 flang/lib/Frontend/FrontendActions.cpp       | 10 +++++++++-
 flang/tools/bbc/bbc.cpp                      |  5 ++++-
 flang/tools/tco/tco.cpp                      |  1 +
 5 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index a565effebfa92e..fbca7b6838ada7 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -328,7 +328,7 @@ inline void createDefaultFIROptimizerPassPipeline(
 /// \param optLevel - optimization level used for creating FIR optimization
 ///   passes pipeline
 inline void createHLFIRToFIRPassPipeline(
-    mlir::PassManager &pm, llvm::OptimizationLevel optLevel = defaultOptLevel) {
+    mlir::PassManager &pm, bool enableOpenMP, llvm::OptimizationLevel optLevel = defaultOptLevel) {
   if (optLevel.isOptimizingForSpeed()) {
     addCanonicalizerPassWithoutRegionSimplification(pm);
     addNestedPassToAllTopLevelOperations(
@@ -345,7 +345,8 @@ inline void createHLFIRToFIRPassPipeline(
   pm.addPass(hlfir::createLowerHLFIRIntrinsics());
   pm.addPass(hlfir::createBufferizeHLFIR());
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  pm.addPass(flangomp::createLowerWorkshare());
+  if (enableOpenMP)
+    pm.addPass(flangomp::createLowerWorkshare());
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
@@ -416,7 +417,7 @@ inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
 ///   passes pipeline
 inline void createMLIRToLLVMPassPipeline(mlir::PassManager &pm,
     MLIRToLLVMPassPipelineConfig &config, llvm::StringRef inputFilename = {}) {
-  fir::createHLFIRToFIRPassPipeline(pm, config.OptLevel);
+  fir::createHLFIRToFIRPassPipeline(pm, config.EnableOpenMP, config.OptLevel);
 
   // Add default optimizer pass pipeline.
   fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h
index 1d890fd8e1f6f9..dd258864ff7f22 100644
--- a/flang/include/flang/Tools/CrossToolHelpers.h
+++ b/flang/include/flang/Tools/CrossToolHelpers.h
@@ -123,6 +123,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
       false; ///< Set no-signed-zeros-fp-math attribute for functions.
   bool UnsafeFPMath = false; ///< Set unsafe-fp-math attribute for functions.
   bool NSWOnLoopVarInc = false; ///< Add nsw flag to loop variable increments.
+  bool EnableOpenMP = false; ///< Enable OpenMP lowering.
 };
 
 struct OffloadModuleOpts {
diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index 5c86bd947ce73f..db5c5649337528 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -711,7 +711,11 @@ void CodeGenAction::lowerHLFIRToFIR() {
   pm.enableVerifier(/*verifyPasses=*/true);
 
   // Create the pass pipeline
-  fir::createHLFIRToFIRPassPipeline(pm, level);
+  fir::createHLFIRToFIRPassPipeline(
+      pm,
+      ci.getInvocation().getFrontendOpts().features.IsEnabled(
+          Fortran::common::LanguageFeature::OpenMP),
+      level);
   (void)mlir::applyPassManagerCLOptions(pm);
 
   if (!mlir::succeeded(pm.run(*mlirModule))) {
@@ -824,6 +828,10 @@ void CodeGenAction::generateLLVMIR() {
     config.VScaleMax = vsr->second;
   }
 
+  if (ci.getInvocation().getFrontendOpts().features.IsEnabled(
+          Fortran::common::LanguageFeature::OpenMP))
+    config.EnableOpenMP = true;
+
   if (ci.getInvocation().getLoweringOpts().getNSWOnLoopVarInc())
     config.NSWOnLoopVarInc = true;
 
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index 07eef065daf6f4..681b23883df44a 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -429,7 +429,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
 
     if (emitFIR && useHLFIR) {
       // lower HLFIR to FIR
-      fir::createHLFIRToFIRPassPipeline(pm, llvm::OptimizationLevel::O2);
+      fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP,
+                                        llvm::OptimizationLevel::O2);
       if (mlir::failed(pm.run(mlirModule))) {
         llvm::errs() << "FATAL: lowering from HLFIR to FIR failed";
         return mlir::failure();
@@ -444,6 +445,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
 
     // Add O2 optimizer pass pipeline.
     MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+    if (enableOpenMP)
+      config.EnableOpenMP = true;
     config.NSWOnLoopVarInc = setNSW;
     fir::registerDefaultInlinerPass(config);
     fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/tools/tco/tco.cpp b/flang/tools/tco/tco.cpp
index a8c64333109aeb..06892cdc3f6a80 100644
--- a/flang/tools/tco/tco.cpp
+++ b/flang/tools/tco/tco.cpp
@@ -138,6 +138,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
       return mlir::failure();
   } else {
     MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+    config.EnableOpenMP = true;  // assume the input contains OpenMP
     config.AliasAnalysis = true; // enabled when optimizing for speed
     if (codeGenLLVM) {
       // Run only CodeGen passes.

>From 0cbfe5078640aa9ca799e1d609345bf13803dab9 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 16:16:38 +0900
Subject: [PATCH 18/19] Implement some missing functionality

---
 flang/include/flang/Optimizer/OpenMP/Passes.h |   3 +
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 113 ++++++++++++------
 2 files changed, 81 insertions(+), 35 deletions(-)

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
index 11fa4e59f891ea..feb395f1a12dbd 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.h
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -25,6 +25,9 @@ namespace flangomp {
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 
+/// Implements the logic specified in the 2.8.3  workshare Construct section of
+/// the OpenMP standard which specifies what statements or constructs shall be
+/// divided into units of work.
 bool shouldUseWorkshareLowering(mlir::Operation *op);
 
 } // namespace flangomp
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index 5998489c13d382..e921b80d0c571e 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -5,7 +5,15 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-// Lower omp workshare construct.
+//
+// This file implements the lowering of omp.workshare to other omp constructs.
+//
+// This pass is tasked with parallelizing the loops nested in
+// workshare_loop_wrapper while both the Fortran to mlir lowering and the hlfir
+// to fir lowering pipelines are responsible for emitting the
+// workshare_loop_wrapper ops where appropriate according to the
+// `shouldUseWorkshareLowering` function.
+//
 //===----------------------------------------------------------------------===//
 
 #include <flang/Optimizer/Builder/FIRBuilder.h>
@@ -44,25 +52,52 @@ namespace flangomp {
 using namespace mlir;
 
 namespace flangomp {
+
+// Checks for nesting pattern below as we need to avoid sharing the work of
+// statements which are nested in some constructs such as omp.critical or
+// another omp.parallel.
+//
+// omp.workshare { // `wsOp`
+//   ...
+//     omp.T { // `parent`
+//       ...
+//         `op`
+//
+template <typename T>
+static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
+  T parent = op->getParentOfType<T>();
+  if (!parent)
+    return false;
+  return wsOp->isProperAncestor(parent);
+}
+
 bool shouldUseWorkshareLowering(Operation *op) {
-  // TODO this is insufficient, as we could have
-  // omp.parallel {
-  //   omp.workshare {
-  //     omp.parallel {
-  //       hlfir.elemental {}
-  //
-  // Then this hlfir.elemental shall _not_ use the lowering for workshare
-  //
-  // Standard says:
-  //   For a parallel construct, the construct is a unit of work with respect to
-  //   the workshare construct. The statements contained in the parallel
-  //   construct are executed by a new thread team.
-  //
-  // TODO similarly for single, critical, etc. Need to think through the
-  // patterns and implement this function.
-  //
-  return op->getParentOfType<omp::WorkshareOp>();
+  auto parentWorkshare = op->getParentOfType<omp::WorkshareOp>();
+
+  if (!parentWorkshare)
+    return false;
+
+  if (isNestedIn<omp::CriticalOp>(parentWorkshare, op))
+    return false;
+
+  // 2.8.3  workshare Construct
+  // For a parallel construct, the construct is a unit of work with respect to
+  // the workshare construct. The statements contained in the parallel construct
+  // are executed by a new thread team.
+  if (isNestedIn<omp::ParallelOp>(parentWorkshare, op))
+    return false;
+
+  // 2.8.2  single Construct
+  // Binding The binding thread set for a single region is the current team. A
+  // single region binds to the innermost enclosing parallel region.
+  // Description Only one of the encountering threads will execute the
+  // structured block associated with the single construct.
+  if (isNestedIn<omp::SingleOp>(parentWorkshare, op))
+    return false;
+
+  return true;
 }
+
 } // namespace flangomp
 
 namespace {
@@ -72,19 +107,27 @@ struct SingleRegion {
 };
 
 static bool mustParallelizeOp(Operation *op) {
-  // TODO as in shouldUseWorkshareLowering we be careful not to pick up
-  // workshare_loop_wrapper in nested omp.parallel ops
-  //
-  // e.g.
-  //
-  // omp.parallel {
-  //   omp.workshare {
-  //     omp.parallel {
-  //       omp.workshare {
-  //         omp.workshare_loop_wrapper {}
   return op
-      ->walk(
-          [](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
+      ->walk([&](Operation *nested) {
+        // We need to be careful not to pick up workshare_loop_wrapper in nested
+        // omp.parallel{omp.workshare} regions, i.e. make sure that `nested`
+        // binds to the workshare region we are currently handling.
+        //
+        // For example:
+        //
+        // omp.parallel {
+        //   omp.workshare { // currently handling this
+        //     omp.parallel {
+        //       omp.workshare { // nested workshare
+        //         omp.workshare_loop_wrapper {}
+        //
+        // Therefore, we skip if we encounter a nested omp.workshare.
+        if (isa<omp::WorkshareOp>(op))
+          WalkResult::skip();
+        if (isa<omp::WorkshareLoopWrapperOp>(op))
+          WalkResult::interrupt();
+        WalkResult::advance();
+      })
       .wasInterrupted();
 }
 
@@ -340,7 +383,8 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 ///
 /// becomes
 ///
-/// omp.single {
+/// %tmp = fir.alloca
+/// omp.single copyprivate(%tmp) {
 ///   %a = fir.allocmem
 ///   fir.store %a %tmp
 /// }
@@ -352,16 +396,15 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
 /// }
 ///
 /// Note that we allocate temporary memory for values in omp.single's which need
-/// to be accessed in all threads in the closest omp.parallel
+/// to be accessed by all threads and broadcast them using single's copyprivate
 LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
   Location loc = wsOp->getLoc();
   IRMapping rootMapping;
 
   OpBuilder rootBuilder(wsOp);
 
-  // TODO We need something like an scf.execute here, but that is not registered
-  // so using omp.workshare as a placeholder. We need this op as our
-  // parallelizeRegion works on regions and not blocks.
+  // This operation is just a placeholder which will be erased later. We need it
+  // because our `parallelizeRegion` function works on regions and not blocks.
   omp::WorkshareOp newOp =
       rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
   if (!wsOp.getNowait())

>From 82bd78d930a7c167013865de27c68b53247de156 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 19 Aug 2024 16:53:22 +0900
Subject: [PATCH 19/19] Fix tests

---
 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp |  6 +--
 .../Transforms/OpenMP/lower-workshare2.mlir   |  2 +
 .../Transforms/OpenMP/lower-workshare3.mlir   |  2 +-
 .../Transforms/OpenMP/lower-workshare4.mlir   |  3 ++
 .../Transforms/OpenMP/lower-workshare6.mlir   | 51 +++++++++++++++++++
 5 files changed, 60 insertions(+), 4 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workshare6.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
index e921b80d0c571e..9557dd200cacee 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -123,10 +123,10 @@ static bool mustParallelizeOp(Operation *op) {
         //
         // Therefore, we skip if we encounter a nested omp.workshare.
         if (isa<omp::WorkshareOp>(op))
-          WalkResult::skip();
+          return WalkResult::skip();
         if (isa<omp::WorkshareLoopWrapperOp>(op))
-          WalkResult::interrupt();
-        WalkResult::advance();
+          return WalkResult::interrupt();
+        return WalkResult::advance();
       })
       .wasInterrupted();
 }
diff --git a/flang/test/Transforms/OpenMP/lower-workshare2.mlir b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
index 325a40d4184453..940662e0bdccc2 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare2.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -1,5 +1,7 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
+// Check that we correctly handle nowait
+
 // CHECK-LABEL:   func.func @nonowait
 func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
   // CHECK: omp.barrier
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
index afb41d95e7198e..09217757512881 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -1,7 +1,7 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
 
-// tests if the correct values are stored
+// Check if we store the correct values
 
 func.func @wsfunc() {
   omp.parallel {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
index 0a70007a9e78dd..44f68cd2ca3654 100644
--- a/flang/test/Transforms/OpenMP/lower-workshare4.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -1,5 +1,8 @@
 // RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
 
+// Check that we clean up unused pure operations from either the parallel or
+// single regions
+
 func.func @wsfunc() {
   %a = fir.alloca i32
   omp.parallel {
diff --git a/flang/test/Transforms/OpenMP/lower-workshare6.mlir b/flang/test/Transforms/OpenMP/lower-workshare6.mlir
new file mode 100644
index 00000000000000..b66f00a47c114f
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare6.mlir
@@ -0,0 +1,51 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// Checks that the omp.workshare_loop_wrapper binds to the correct omp.workshare
+
+func.func @wsfunc() {
+  %c1 = arith.constant 1 : index
+  %c42 = arith.constant 42 : index
+  omp.parallel {
+    omp.workshare nowait {
+      omp.parallel {
+        omp.workshare nowait {
+          omp.workshare_loop_wrapper {
+            omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+              "test.test2"() : () -> ()
+              omp.yield
+            }
+            omp.terminator
+          }
+          omp.terminator
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @wsfunc() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
+// CHECK:           omp.parallel {
+// CHECK:             omp.single nowait {
+// CHECK:               omp.parallel {
+// CHECK:                 omp.wsloop nowait {
+// CHECK:                   omp.loop_nest (%[[VAL_2:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_1]]) inclusive step (%[[VAL_0]]) {
+// CHECK:                     "test.test2"() : () -> ()
+// CHECK:                     omp.yield
+// CHECK:                   }
+// CHECK:                   omp.terminator
+// CHECK:                 }
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+



More information about the llvm-branch-commits mailing list