[llvm-branch-commits] [flang] [flang] Lower omp.workshare to other omp constructs (PR #101446)

Tom Eccles via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 22 03:05:07 PDT 2024


================
@@ -0,0 +1,446 @@
+//===- LowerWorkshare.cpp - lowering of omp.workshare ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering of omp.workshare to other omp constructs.
+//
+// This pass is tasked with parallelizing the loops nested in
+// workshare.loop_wrapper while both the Fortran to mlir lowering and the hlfir
+// to fir lowering pipelines are responsible for emitting the
+// workshare.loop_wrapper ops where appropriate according to the
+// `shouldUseWorkshareLowering` function.
+//
+//===----------------------------------------------------------------------===//
+
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+
+// Checks for nesting pattern below as we need to avoid sharing the work of
+// statements which are nested in some constructs such as omp.critical or
+// another omp.parallel.
+//
+// omp.workshare { // `wsOp`
+//   ...
+//     omp.T { // `parent`
+//       ...
+//         `op`
+//
+template <typename T>
+static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
+  // `op` only counts as nested when it has an enclosing T and that T is
+  // itself strictly inside `wsOp`.
+  if (T enclosing = op->getParentOfType<T>())
+    return wsOp->isProperAncestor(enclosing);
+  return false;
+}
+
+/// Returns true when `op` has an enclosing omp.workshare and is not wrapped,
+/// inside that workshare, by an omp.critical, omp.parallel or omp.single.
+bool shouldUseWorkshareLowering(Operation *op) {
+  auto enclosingWs = op->getParentOfType<omp::WorkshareOp>();
+  if (!enclosingWs)
+    return false;
+
+  // The work of statements nested in any of the constructs below must not be
+  // shared:
+  //
+  // * omp.critical.
+  //
+  // * omp.parallel -- 2.8.3  workshare Construct
+  //   For a parallel construct, the construct is a unit of work with respect
+  //   to the workshare construct. The statements contained in the parallel
+  //   construct are executed by a new thread team.
+  //
+  // * omp.single -- 2.8.2  single Construct
+  //   Binding: The binding thread set for a single region is the current
+  //   team. A single region binds to the innermost enclosing parallel region.
+  //   Description: Only one of the encountering threads will execute the
+  //   structured block associated with the single construct.
+  const bool inExcludedConstruct =
+      isNestedIn<omp::CriticalOp>(enclosingWs, op) ||
+      isNestedIn<omp::ParallelOp>(enclosingWs, op) ||
+      isNestedIn<omp::SingleOp>(enclosingWs, op);
+  return !inExcludedConstruct;
+}
+
+} // namespace flangomp
+
+namespace {
+
+/// A run of consecutive operations within one block, described as the
+/// half-open iterator range [begin, end).
+struct SingleRegion {
+  /// First operation of the run.
+  Block::iterator begin;
+  /// One past the last operation of the run.
+  Block::iterator end;
+};
+
+/// Returns true if `op` is, or contains, an omp.workshare.loop_wrapper that
+/// binds to the omp.workshare currently being handled (i.e. one that is not
+/// hidden behind a nested omp.workshare).
+static bool mustParallelizeOp(Operation *op) {
+  // The walk has to be pre-order: WalkResult::skip() only prunes nested
+  // regions that have not been visited yet, so with the default post-order
+  // traversal the body of a nested omp.workshare would already have been
+  // walked by the time the skip is returned.
+  return op
+      ->walk<WalkOrder::PreOrder>([](Operation *nested) {
+        // We need to be careful not to pick up workshare.loop_wrapper in nested
+        // omp.parallel{omp.workshare} regions, i.e. make sure that `nested`
+        // binds to the workshare region we are currently handling.
+        //
+        // For example:
+        //
+        // omp.parallel {
+        //   omp.workshare { // currently handling this
+        //     omp.parallel {
+        //       omp.workshare { // nested workshare
+        //         omp.workshare.loop_wrapper {}
+        //
+        // Therefore, we skip if we encounter a nested omp.workshare.
+        //
+        // Note that the checks are made on `nested` (the operation being
+        // visited), not on the root `op` of the walk.
+        if (isa<omp::WorkshareOp>(nested))
+          return WalkResult::skip();
+        if (isa<omp::WorkshareLoopWrapperOp>(nested))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
+}
+
+/// Returns true for operations that may be cloned into both the single and
+/// the parallel code: hlfir/fir declare ops and memory-effect-free ops.
+static bool isSafeToParallelize(Operation *op) {
+  if (isa<hlfir::DeclareOp, fir::DeclareOp>(op))
+    return true;
+  return isMemoryEffectFree(op);
+}
+
+/// Returns the private module-level function that shallow-copies a value of
+/// the element type of `varType` (which must be a fir.ref type), creating it
+/// on first use. The function takes two references, loads from the first and
+/// stores the loaded value into the second.
+///
+/// Simple shallow copies suffice for our purposes in this pass, so this is a
+/// simpler alternative to the full fledged `createCopyFunc` in the frontend.
+static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
+                                         fir::FirOpBuilder builder) {
+  mlir::ModuleOp mod = builder.getModule();
+  mlir::Type pointeeTy = cast<fir::ReferenceType>(varType).getEleTy();
+  std::string symName =
+      fir::getTypeAsString(pointeeTy, builder.getKindMap(), "_workshare_copy");
+
+  // Reuse a previously generated copy function for this element type.
+  if (auto existing = mod.lookupSymbol<mlir::func::FuncOp>(symName))
+    return existing;
+
+  // Otherwise declare `(varType, varType) -> ()` at module scope.
+  mlir::OpBuilder::InsertionGuard restoreInsertionPoint(builder);
+  mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
+  llvm::SmallVector<mlir::Type> paramTypes{varType, varType};
+  auto signature =
+      mlir::FunctionType::get(builder.getContext(), paramTypes, {});
+  mlir::func::FuncOp copyFn =
+      moduleBuilder.create<mlir::func::FuncOp>(loc, symName, signature);
+  copyFn.setVisibility(mlir::SymbolTable::Visibility::Private);
+
+  // Build the entry block and emit the body: load, store, return.
+  mlir::Region &body = copyFn.getRegion();
+  mlir::Block *entry =
+      builder.createBlock(&body, body.end(), paramTypes, {loc, loc});
+  builder.setInsertionPointToStart(entry);
+
+  Value loaded = builder.create<fir::LoadOp>(loc, copyFn.getArgument(0));
+  builder.create<fir::StoreOp>(loc, loaded, copyFn.getArgument(1));
+  builder.create<mlir::func::ReturnOp>(loc);
+
+  return copyFn;
+}
+
+/// Returns true if `user` lies outside of the single region `sr`.
+///
+/// `user` may be nested arbitrarily deep; it is first replaced by its ancestor
+/// whose parent operation is `parentOp`, and that ancestor is then compared
+/// against the bounds of `sr`.
+static bool isUserOutsideSR(Operation *user, Operation *parentOp,
+                            SingleRegion sr) {
+  Operation *anchor = user;
+  while (anchor->getParentOp() != parentOp)
+    anchor = anchor->getParentOp();
+
+  // A different block can never be part of `sr`.
+  if (anchor->getBlock() != sr.begin->getBlock())
+    return true;
+
+  // At or past the end of the region.
+  if (!anchor->isBeforeInBlock(&*sr.end))
+    return true;
+
+  // NOTE(review): this bound is strict, so an anchor equal to `sr.begin`
+  // itself is reported as outside -- confirm this is intended.
+  return !sr.begin->isBeforeInBlock(anchor);
+}
+
+/// Returns true if `v` is used outside of `sr`, either directly or through the
+/// results of safe-to-parallelize operations in the block of `sr` that consume
+/// it.
+static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
+  Block *regionBlock = sr.begin->getBlock();
+  Operation *regionOwner = regionBlock->getParentOp();
+
+  for (OpOperand &operand : v.getUses()) {
+    Operation *consumer = operand.getOwner();
+    if (isUserOutsideSR(consumer, regionOwner, sr))
+      return true;
+
+    // Only follow the results of consumers that sit directly in the block of
+    // the SR (results of nested users cannot be used outside of the SR) and
+    // that are safe to parallelize (the others are handled separately).
+    const bool followResults =
+        consumer->getBlock() == regionBlock && isSafeToParallelize(consumer);
+    if (!followResults)
+      continue;
+
+    if (llvm::any_of(consumer->getResults(), [&](Value result) {
+          return isTransitivelyUsedOutside(result, sr);
+        }))
+      return true;
+  }
+  return false;
+}
+
+/// Pure operations are cloned into both the parallel and the single blocks.
+/// This function erases the ones in `block` that are trivially dead.
+static void cleanupBlock(Block *block) {
+  // Visit the operations last-to-first, so that users are visited before the
+  // operations that define their operands.
+  for (Operation &candidate :
+       llvm::make_early_inc_range(llvm::reverse(*block))) {
+    if (!isOpTriviallyDead(&candidate))
+      continue;
+    candidate.erase();
+  }
+}
+
+static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
+                              IRMapping &rootMapping, Location loc,
+                              mlir::DominanceInfo &di) {
+  OpBuilder rootBuilder(sourceRegion.getContext());
+  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
+  OpBuilder copyFuncBuilder(m.getBodyRegion());
+  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
+
+  auto mapReloadedValue =
+      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
+          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
+    if (auto reloaded = rootMapping.lookupOrNull(v))
+      return nullptr;
+    Type ty = v.getType();
+    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
+    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
+    rootMapping.map(v, reloaded);
+    return alloc;
+  };
+
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
+                          OpBuilder singleBuilder,
+                          OpBuilder parallelBuilder) -> SmallVector<Value> {
+    IRMapping singleMapping = rootMapping;
+    SmallVector<Value> copyPrivate;
+
+    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+      if (isSafeToParallelize(&op)) {
+        singleBuilder.clone(op, singleMapping);
+        parallelBuilder.clone(op, rootMapping);
+      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
+        auto hoisted =
+            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
+        rootMapping.map(&*alloca, &*hoisted);
+        rootMapping.map(alloca.getResult(), hoisted.getResult());
+        copyPrivate.push_back(hoisted);
+      } else {
+        singleBuilder.clone(op, singleMapping);
+        // Prepare reloaded values for results of operations that cannot be
+        // safely parallelized and which are used after the region `sr`
+        for (auto res : op.getResults()) {
+          if (isTransitivelyUsedOutside(res, sr)) {
+            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
+                                          parallelBuilder, singleMapping);
+            if (alloc)
+              copyPrivate.push_back(alloc);
+          }
+        }
+      }
+    }
+    singleBuilder.create<omp::TerminatorOp>(loc);
+    return copyPrivate;
+  };
+
+  for (Block &block : sourceRegion) {
+    Block *targetBlock = rootBuilder.createBlock(
+        &targetRegion, {}, block.getArgumentTypes(),
+        llvm::map_to_vector(block.getArguments(),
+                            [](BlockArgument arg) { return arg.getLoc(); }));
+    rootMapping.map(&block, targetBlock);
+    rootMapping.map(block.getArguments(), targetBlock->getArguments());
+  }
+
+  auto handleOneBlock = [&](Block &block) {
+    Block &targetBlock = *rootMapping.lookup(&block);
+    rootBuilder.setInsertionPointToStart(&targetBlock);
+    Operation *terminator = block.getTerminator();
+    SmallVector<std::variant<SingleRegion, Operation *>> regions;
+
+    auto it = block.begin();
+    auto getOneRegion = [&]() {
+      if (&*it == terminator)
+        return false;
+      if (mustParallelizeOp(&*it)) {
+        regions.push_back(&*it);
+        it++;
+        return true;
+      }
+      SingleRegion sr;
+      sr.begin = it;
+      while (&*it != terminator && !mustParallelizeOp(&*it))
+        it++;
+      sr.end = it;
+      assert(sr.begin != sr.end);
+      regions.push_back(sr);
+      return true;
+    };
+    while (getOneRegion())
+      ;
+
+    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
+      bool isLast = i + 1 == regions.size();
+      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+        OpBuilder singleBuilder(sourceRegion.getContext());
+        Block *singleBlock = new Block();
----------------
tblah wrote:

I'm not sure about this. Everywhere else in flang we use `OpBuilder::createBlock`.

I think the only difference would be that builder listeners would be notified, and we don't have any here. But this could be surprising when the code is changed. 

But I can see from the implementation of `createBlock` that this isn't incorrect so it is okay with me if you have a good reason?

https://github.com/llvm/llvm-project/pull/101446


More information about the llvm-branch-commits mailing list