[Mlir-commits] [mlir] [mlir][scf] Add reductions support to `scf.parallel` fusion (PR #75955)

Ivan Butygin llvmlistbot at llvm.org
Fri Jan 5 06:08:45 PST 2024


================
@@ -131,29 +131,83 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
 }
 
 /// Prepends operations of firstPloop's body into secondPloop's body.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
-                        OpBuilder b,
+/// Updates secondPloop with new loop.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
+                        OpBuilder builder,
                         llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *block1 = firstPloop.getBody();
+  Block *block2 = secondPloop.getBody();
   IRMapping firstToSecondPloopIndices;
-  firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
-                                secondPloop.getBody()->getArguments());
+  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
 
   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
                      mayAlias))
     return;
 
-  b.setInsertionPointToStart(secondPloop.getBody());
-  for (auto &op : firstPloop.getBody()->without_terminator())
-    b.clone(op, firstToSecondPloopIndices);
+  DominanceInfo dom;
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return;
+
+  ValueRange inits1 = firstPloop.getInitVals();
+  ValueRange inits2 = secondPloop.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  IRRewriter b(builder);
+  b.setInsertionPoint(secondPloop);
+  auto newSecondPloop = b.create<ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+  Block *newBlock = newSecondPloop.getBody();
+  auto term1 = cast<ReduceOp>(block1->getTerminator());
+  auto term2 = cast<ReduceOp>(block2->getTerminator());
+
+  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
----------------
Hardcode84 wrote:

Each reduction corresponds to a result value of the parent `scf.parallel` op, so if the first loop had any reductions, those results must become part of the fused parent op, which changes the total result count.

https://github.com/llvm/llvm-project/pull/75955


More information about the Mlir-commits mailing list