[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)

Ivan Butygin llvmlistbot at llvm.org
Thu Jun 27 11:30:47 PDT 2024


================
@@ -174,61 +158,10 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
                      mayAlias))
     return;
 
-  DominanceInfo dom;
-  // We are fusing first loop into second, make sure there are no users of the
-  // first loop results between loops.
-  for (Operation *user : firstPloop->getUsers())
-    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
-      return;
-
-  ValueRange inits1 = firstPloop.getInitVals();
-  ValueRange inits2 = secondPloop.getInitVals();
-
-  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
-  newInitVars.append(inits2.begin(), inits2.end());
-
-  IRRewriter b(builder);
-  b.setInsertionPoint(secondPloop);
-  auto newSecondPloop = b.create<ParallelOp>(
-      secondPloop.getLoc(), secondPloop.getLowerBound(),
-      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
-  Block *newBlock = newSecondPloop.getBody();
-  auto term1 = cast<ReduceOp>(block1->getTerminator());
-  auto term2 = cast<ReduceOp>(block2->getTerminator());
-
-  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-
-  ValueRange results = newSecondPloop.getResults();
-  if (!results.empty()) {
-    b.setInsertionPointToEnd(newBlock);
-
-    ValueRange reduceArgs1 = term1.getOperands();
-    ValueRange reduceArgs2 = term2.getOperands();
-    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
-    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
-    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
-    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
-             term1.getReductions(), term2.getReductions()))) {
-      Block &oldRedBlock = reg.front();
-      Block &newRedBlock = newReduceOp.getReductions()[i].front();
-      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
-                          newRedBlock.getArguments());
-    }
-
-    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
-    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
-  }
-  term1->erase();
-  term2->erase();
-  firstPloop.erase();
-  secondPloop.erase();
-  secondPloop = newSecondPloop;
+  IRRewriter rewriter(builder);
+  secondPloop = mlir::fuseIndependentSiblingParallelLoops(
+      firstPloop, secondPloop, rewriter);
+  ;
----------------
Hardcode84 wrote:

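nit: extra `;` — there is a stray empty statement on the line after the `fuseIndependentSiblingParallelLoops(...)` call.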

https://github.com/llvm/llvm-project/pull/94391


More information about the Mlir-commits mailing list