[llvm-branch-commits] [mlir] [MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) (PR #89212)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Apr 19 05:54:40 PDT 2024
================
@@ -461,18 +461,51 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Replace the loop.
{
OpBuilder::InsertionGuard allocaGuard(rewriter);
- auto loop = rewriter.create<omp::WsloopOp>(
+ // Create worksharing loop wrapper.
+ auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+ if (!reductionVariables.empty()) {
+ wsloopOp.setReductionsAttr(
+ ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+ wsloopOp.getReductionVarsMutable().append(reductionVariables);
+ }
+ rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+
+ // The wrapper's entry block arguments will define the reduction
+ // variables.
+ llvm::SmallVector<mlir::Type> reductionTypes;
+ reductionTypes.reserve(reductionVariables.size());
+ llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
+ [](mlir::Value v) { return v.getType(); });
+ rewriter.createBlock(
+ &wsloopOp.getRegion(), {}, reductionTypes,
+ llvm::SmallVector<mlir::Location>(reductionVariables.size(),
+ parallelOp.getLoc()));
+
+ rewriter.setInsertionPoint(
+ rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()));
+
+ // Create loop nest and populate region with contents of scf.parallel.
+ auto loopOp = rewriter.create<omp::LoopNestOp>(
parallelOp.getLoc(), parallelOp.getLowerBound(),
parallelOp.getUpperBound(), parallelOp.getStep());
- rewriter.create<omp::TerminatorOp>(loc);
- rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
- loop.getRegion().begin());
+ rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
+ loopOp.getRegion().begin());
+
+ // Remove reduction-related block arguments from omp.loop_nest and
+ // redirect uses to the corresponding omp.wsloop block argument.
+ mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
+ unsigned numLoops = parallelOp.getNumLoops();
+ rewriter.replaceAllUsesWith(
+ loopOpEntryBlock.getArguments().drop_front(numLoops),
+ wsloopOp.getRegion().getArguments());
+ loopOpEntryBlock.eraseArguments(
+ numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
- Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
- loop.getRegion().begin()->begin());
+ Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(),
+ loopOp.getRegion().begin()->begin());
----------------
skatrak wrote:
Thanks for the suggestion, done!
https://github.com/llvm/llvm-project/pull/89212
More information about the llvm-branch-commits mailing list