[llvm-branch-commits] [mlir] [MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) (PR #89212)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Apr 19 05:54:40 PDT 2024


================
@@ -461,18 +461,51 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
       // Replace the loop.
       {
         OpBuilder::InsertionGuard allocaGuard(rewriter);
-        auto loop = rewriter.create<omp::WsloopOp>(
+        // Create worksharing loop wrapper.
+        auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+        if (!reductionVariables.empty()) {
+          wsloopOp.setReductionsAttr(
+              ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+          wsloopOp.getReductionVarsMutable().append(reductionVariables);
+        }
+        rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+
+        // The wrapper's entry block arguments will define the reduction
+        // variables.
+        llvm::SmallVector<mlir::Type> reductionTypes;
+        reductionTypes.reserve(reductionVariables.size());
+        llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
+                        [](mlir::Value v) { return v.getType(); });
+        rewriter.createBlock(
+            &wsloopOp.getRegion(), {}, reductionTypes,
+            llvm::SmallVector<mlir::Location>(reductionVariables.size(),
+                                              parallelOp.getLoc()));
+
+        rewriter.setInsertionPoint(
+            rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()));
+
+        // Create loop nest and populate region with contents of scf.parallel.
+        auto loopOp = rewriter.create<omp::LoopNestOp>(
             parallelOp.getLoc(), parallelOp.getLowerBound(),
             parallelOp.getUpperBound(), parallelOp.getStep());
-        rewriter.create<omp::TerminatorOp>(loc);
 
-        rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
-                                    loop.getRegion().begin());
+        rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
+                                    loopOp.getRegion().begin());
+
+        // Remove reduction-related block arguments from omp.loop_nest and
+        // redirect uses to the corresponding omp.wsloop block argument.
+        mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
+        unsigned numLoops = parallelOp.getNumLoops();
+        rewriter.replaceAllUsesWith(
+            loopOpEntryBlock.getArguments().drop_front(numLoops),
+            wsloopOp.getRegion().getArguments());
+        loopOpEntryBlock.eraseArguments(
+            numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
 
-        Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
-                                         loop.getRegion().begin()->begin());
+        Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(),
+                                         loopOp.getRegion().begin()->begin());
----------------
skatrak wrote:

Thanks for the suggestion, done!

https://github.com/llvm/llvm-project/pull/89212


More information about the llvm-branch-commits mailing list