[llvm-branch-commits] [flang] [flang] Lower omp.workshare to other omp constructs (PR #101446)
Tom Eccles via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Aug 22 03:05:07 PDT 2024
================
@@ -0,0 +1,446 @@
+//===- LowerWorkshare.cpp - lower omp.workshare to other omp constructs ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering of omp.workshare to other omp constructs.
+//
+// This pass is tasked with parallelizing the loops nested in
+// workshare.loop_wrapper, while both the Fortran-to-MLIR lowering and the
+// HLFIR-to-FIR lowering pipelines are responsible for emitting the
+// workshare.loop_wrapper ops where appropriate, according to the
+// `shouldUseWorkshareLowering` function.
+//
+//===----------------------------------------------------------------------===//
+
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+
+/// Returns true when `op` has an enclosing operation of type `T` and the
+/// innermost such operation lies strictly inside `wsOp`:
+///
+///   omp.workshare {   // `wsOp`
+///     ...
+///     omp.T {         // innermost enclosing T
+///       ...
+///       `op`
+///
+/// Used to avoid sharing the work of statements that are nested in constructs
+/// such as omp.critical or another omp.parallel.
+template <typename T>
+static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
+  if (T enclosing = op->getParentOfType<T>())
+    return wsOp->isProperAncestor(enclosing);
+  return false;
+}
+
+/// Decides whether `op` is lowered with the workshare lowering.
+///
+/// Returns false if `op` has no enclosing omp.workshare. It also returns false
+/// if an omp.critical, omp.parallel or omp.single lies between the innermost
+/// enclosing omp.workshare and `op`. Otherwise returns true.
+bool shouldUseWorkshareLowering(Operation *op) {
+  auto workshare = op->getParentOfType<omp::WorkshareOp>();
+  if (!workshare)
+    return false;
+
+  // omp.critical: statements inside it are not workshared.
+  //
+  // 2.8.3 workshare Construct
+  // For a parallel construct, the construct is a unit of work with respect to
+  // the workshare construct. The statements contained in the parallel construct
+  // are executed by a new thread team.
+  //
+  // 2.8.2 single Construct
+  // Binding The binding thread set for a single region is the current team. A
+  // single region binds to the innermost enclosing parallel region.
+  // Description Only one of the encountering threads will execute the
+  // structured block associated with the single construct.
+  const bool boundToAnotherConstruct =
+      isNestedIn<omp::CriticalOp>(workshare, op) ||
+      isNestedIn<omp::ParallelOp>(workshare, op) ||
+      isNestedIn<omp::SingleOp>(workshare, op);
+  return !boundToAnotherConstruct;
+}
+
+} // namespace flangomp
+
+namespace {
+
+/// A half-open range [begin, end) of consecutive operations within one block
+/// that are grouped together and handled as a unit by this pass.
+struct SingleRegion {
+  Block::iterator begin;
+  Block::iterator end;
+};
+
+/// Returns true if `op` is, or transitively contains, an
+/// omp.workshare.loop_wrapper that binds to the workshare region currently
+/// being handled. Callers keep such operations out of the single regions.
+static bool mustParallelizeOp(Operation *op) {
+  // A pre-order walk is required: WalkResult::skip() only prunes the children
+  // of the visited operation when the callback runs before them. In the
+  // default post-order walk, the children have already been visited.
+  return op
+      ->walk<WalkOrder::PreOrder>([](Operation *nested) {
+        // We need to be careful not to pick up workshare.loop_wrapper in nested
+        // omp.parallel{omp.workshare} regions, i.e. make sure that `nested`
+        // binds to the workshare region we are currently handling.
+        //
+        // For example:
+        //
+        // omp.parallel {
+        //   omp.workshare { // currently handling this
+        //     omp.parallel {
+        //       omp.workshare { // nested workshare
+        //         omp.workshare.loop_wrapper {}
+        //
+        // Therefore, we skip (do not descend into) a nested omp.workshare.
+        // Note that the checks must inspect `nested`, the operation being
+        // visited, not the walk root `op`.
+        if (isa<omp::WorkshareOp>(nested))
+          return WalkResult::skip();
+        if (isa<omp::WorkshareLoopWrapperOp>(nested))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
+}
+
+/// Operations that this pass clones so that both the single and the parallel
+/// code get a copy: hlfir/fir declare ops and memory-effect-free operations.
+static bool isSafeToParallelize(Operation *op) {
+  if (isa<hlfir::DeclareOp, fir::DeclareOp>(op))
+    return true;
+  return isMemoryEffectFree(op);
+}
+
+/// Simple shallow copies suffice for our purposes in this pass, so we implement
+/// this simpler alternative to the full fledged `createCopyFunc` in the
+/// frontend.
+///
+/// Returns a private module-level function `(dst: varType, src: varType)` that
+/// loads the value behind `src` and stores it into `dst`. The function is
+/// created on first use and looked up by name afterwards. `varType` must be a
+/// `!fir.ref<T>`. The function name is derived from `T`, so one helper is
+/// shared by all variables with the same element type.
+static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
+                                         fir::FirOpBuilder builder) {
+  mlir::ModuleOp module = builder.getModule();
+  auto rt = cast<fir::ReferenceType>(varType);
+  mlir::Type eleTy = rt.getEleTy();
+  std::string copyFuncName =
+      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+
+  // Reuse a helper that was already generated for this element type.
+  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
+    return decl;
+
+  // Create the function in the module body.
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
+  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
+  mlir::func::FuncOp funcOp =
+      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+                      {loc, loc});
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+
+  // copyprivate copy functions are invoked as copy(dst, src): argument 0 is
+  // the receiving thread's variable and argument 1 is the source variable
+  // from the thread that executed the single region. Copy *src into *dst.
+  Value dst = funcOp.getArgument(0);
+  Value src = funcOp.getArgument(1);
+  Value loaded = builder.create<fir::LoadOp>(loc, src);
+  builder.create<fir::StoreOp>(loc, loaded, dst);
+
+  builder.create<mlir::func::ReturnOp>(loc);
+  return funcOp;
+}
+
+/// Returns true unless `user` (or its ancestor that is nested directly in
+/// `parentOp`'s regions) lies in the same block as `sr` and falls strictly
+/// between `*sr.begin` and `*sr.end`.
+///
+/// Assumes `user` is nested somewhere under `parentOp`. The ancestor climb
+/// does not stop at a null parent.
+static bool isUserOutsideSR(Operation *user, Operation *parentOp,
+                            SingleRegion sr) {
+  Operation *ancestor = user;
+  while (ancestor->getParentOp() != parentOp)
+    ancestor = ancestor->getParentOp();
+
+  if (ancestor->getBlock() != sr.begin->getBlock())
+    return true;
+
+  const bool strictlyInside = ancestor->isBeforeInBlock(&*sr.end) &&
+                              sr.begin->isBeforeInBlock(ancestor);
+  return !strictlyInside;
+}
+
+/// Returns true if `v` is used outside the single region `sr`, either directly
+/// or through the results of users that are safe to parallelize and sit
+/// directly in the region's block. Such results are followed recursively.
+static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
+  Block *regionBlock = sr.begin->getBlock();
+  Operation *regionOwner = regionBlock->getParentOp();
+
+  for (auto &operand : v.getUses()) {
+    Operation *consumer = operand.getOwner();
+    if (isUserOutsideSR(consumer, regionOwner, sr))
+      return true;
+
+    // Results of nested users cannot be used outside of the SR, and users
+    // that are not safe to parallelize are handled separately. Only direct,
+    // parallelizable users have their results followed.
+    const bool followResults = consumer->getBlock() == regionBlock &&
+                               isSafeToParallelize(consumer);
+    if (!followResults)
+      continue;
+
+    for (Value result : consumer->getResults()) {
+      if (isTransitivelyUsedOutside(result, sr))
+        return true;
+    }
+  }
+  return false;
+}
+
+/// Pure operations are cloned into both the parallel and the single blocks,
+/// so some clones may end up without uses. This erases every trivially dead
+/// operation in `block`. Operations are visited from last to first, so a dead
+/// user is erased before the operations defining its operands are checked.
+static void cleanupBlock(Block *block) {
+  auto backwards = llvm::make_range(block->rbegin(), block->rend());
+  for (Operation &candidate : llvm::make_early_inc_range(backwards)) {
+    if (isOpTriviallyDead(&candidate))
+      candidate.erase();
+  }
+}
+
+static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
+ IRMapping &rootMapping, Location loc,
+ mlir::DominanceInfo &di) {
+ OpBuilder rootBuilder(sourceRegion.getContext());
+ ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
+ OpBuilder copyFuncBuilder(m.getBodyRegion());
+ fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
+
+ auto mapReloadedValue =
+ [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
+ OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
+ if (auto reloaded = rootMapping.lookupOrNull(v))
+ return nullptr;
+ Type ty = v.getType();
+ Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
+ singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+ Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
+ rootMapping.map(v, reloaded);
+ return alloc;
+ };
+
+ auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
+ OpBuilder singleBuilder,
+ OpBuilder parallelBuilder) -> SmallVector<Value> {
+ IRMapping singleMapping = rootMapping;
+ SmallVector<Value> copyPrivate;
+
+ for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+ if (isSafeToParallelize(&op)) {
+ singleBuilder.clone(op, singleMapping);
+ parallelBuilder.clone(op, rootMapping);
+ } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
+ auto hoisted =
+ cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
+ rootMapping.map(&*alloca, &*hoisted);
+ rootMapping.map(alloca.getResult(), hoisted.getResult());
+ copyPrivate.push_back(hoisted);
+ } else {
+ singleBuilder.clone(op, singleMapping);
+ // Prepare reloaded values for results of operations that cannot be
+ // safely parallelized and which are used after the region `sr`
+ for (auto res : op.getResults()) {
+ if (isTransitivelyUsedOutside(res, sr)) {
+ auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
+ parallelBuilder, singleMapping);
+ if (alloc)
+ copyPrivate.push_back(alloc);
+ }
+ }
+ }
+ }
+ singleBuilder.create<omp::TerminatorOp>(loc);
+ return copyPrivate;
+ };
+
+ for (Block &block : sourceRegion) {
+ Block *targetBlock = rootBuilder.createBlock(
+ &targetRegion, {}, block.getArgumentTypes(),
+ llvm::map_to_vector(block.getArguments(),
+ [](BlockArgument arg) { return arg.getLoc(); }));
+ rootMapping.map(&block, targetBlock);
+ rootMapping.map(block.getArguments(), targetBlock->getArguments());
+ }
+
+ auto handleOneBlock = [&](Block &block) {
+ Block &targetBlock = *rootMapping.lookup(&block);
+ rootBuilder.setInsertionPointToStart(&targetBlock);
+ Operation *terminator = block.getTerminator();
+ SmallVector<std::variant<SingleRegion, Operation *>> regions;
+
+ auto it = block.begin();
+ auto getOneRegion = [&]() {
+ if (&*it == terminator)
+ return false;
+ if (mustParallelizeOp(&*it)) {
+ regions.push_back(&*it);
+ it++;
+ return true;
+ }
+ SingleRegion sr;
+ sr.begin = it;
+ while (&*it != terminator && !mustParallelizeOp(&*it))
+ it++;
+ sr.end = it;
+ assert(sr.begin != sr.end);
+ regions.push_back(sr);
+ return true;
+ };
+ while (getOneRegion())
+ ;
+
+ for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
+ bool isLast = i + 1 == regions.size();
+ if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+ OpBuilder singleBuilder(sourceRegion.getContext());
+ Block *singleBlock = new Block();
----------------
tblah wrote:
I'm not sure about this. Everywhere else in flang we use `OpBuilder::createBlock`.
I think the only difference would be that builder listeners would be notified, and we don't have any here. But this could be surprising when the code is changed.
But I can see from the implementation of `createBlock` that this isn't incorrect so it is okay with me if you have a good reason?
https://github.com/llvm/llvm-project/pull/101446
More information about the llvm-branch-commits
mailing list