[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Tue Feb 13 03:56:53 PST 2024
================
@@ -970,3 +970,69 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
return fusedLoop;
}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+ scf::ForOp source,
+ RewriterBase &rewriter) {
+ // Create fused init_args.
+ auto targetInitArgs = target.getInitArgs();
+ auto sourceInitArgs = source.getInitArgs();
+ SmallVector<Value> fusedInitArgs;
+ fusedInitArgs.reserve(targetInitArgs.size() + sourceInitArgs.size());
+ fusedInitArgs.append(sourceInitArgs.begin(), sourceInitArgs.end());
+ fusedInitArgs.append(targetInitArgs.begin(), targetInitArgs.end());
+
+ // Create a new scf::for op after the source loop.
+ rewriter.setInsertionPointAfter(source);
+ scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+ source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+ source.getStep(), fusedInitArgs);
+
+ SmallVector<Value> yieldResults;
+
+ // First merge source loop into the new (fused) for loop and then target loop.
+ rewriter.setInsertionPointToEnd(fusedLoop.getBody());
+ for (auto loopAndInitArgsBegin :
+ {std::pair(source, (unsigned int)0),
+ std::pair(target, source.getNumRegionIterArgs())}) {
+ auto origLoop = loopAndInitArgsBegin.first;
+ IRMapping mapping;
+
+ mapping.map(origLoop.getInductionVar(), fusedLoop.getInductionVar());
+ for (size_t i = 0; i < origLoop.getNumRegionIterArgs(); ++i) {
+ mapping.map(
+ origLoop.getRegionIterArgs()[i],
+ fusedLoop.getRegionIterArgs()[loopAndInitArgsBegin.second + i]);
+ }
+
+ for (Operation &op : origLoop.getBody()->getOperations()) {
+ rewriter.clone(op, mapping);
+ }
+
+ if (origLoop.getNumResults() > 0) {
+ scf::YieldOp yieldFromOrigLoop =
+ cast<scf::YieldOp>(fusedLoop.getBody()->getTerminator());
+ yieldResults.append(yieldFromOrigLoop.getOperands().begin(),
+ yieldFromOrigLoop.getOperands().end());
+ rewriter.eraseOp(yieldFromOrigLoop);
+ }
+ }
+
+ // Construct combined YieldOp
+ rewriter.setInsertionPointToEnd(fusedLoop.getBody());
+ rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+ // Replace all uses of the old loops with the fused loop.
+ unsigned numSourceOuts = source.getNumResults();
+ rewriter.replaceAllUsesWith(source.getResults(),
+ fusedLoop.getResults().slice(0, numSourceOuts));
----------------
ftynse wrote:
Nit: consider `take_front`/`drop_front` as more readable alternatives to `slice`.
https://github.com/llvm/llvm-project/pull/81495
More information about the Mlir-commits mailing list