[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)

Rolf Morel llvmlistbot at llvm.org
Mon Mar 25 04:43:38 PDT 2024


================
@@ -970,3 +970,69 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
 
   return fusedLoop;
 }
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+                                                scf::ForOp source,
+                                                RewriterBase &rewriter) {
+  // Create fused init_args.
+  auto targetInitArgs = target.getInitArgs();
+  auto sourceInitArgs = source.getInitArgs();
+  SmallVector<Value> fusedInitArgs;
+  fusedInitArgs.reserve(targetInitArgs.size() + sourceInitArgs.size());
+  fusedInitArgs.append(sourceInitArgs.begin(), sourceInitArgs.end());
+  fusedInitArgs.append(targetInitArgs.begin(), targetInitArgs.end());
+
+  // Create a new scf::for op after the source loop.
+  rewriter.setInsertionPointAfter(source);
+  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), fusedInitArgs);
+
+  SmallVector<Value> yieldResults;
+
+  // First merge source loop into the new (fused) for loop and then target loop.
+  rewriter.setInsertionPointToEnd(fusedLoop.getBody());
+  for (auto loopAndInitArgsBegin :
+       {std::pair(source, (unsigned int)0),
+        std::pair(target, source.getNumRegionIterArgs())}) {
+    auto origLoop = loopAndInitArgsBegin.first;
+    IRMapping mapping;
+
+    mapping.map(origLoop.getInductionVar(), fusedLoop.getInductionVar());
+    for (size_t i = 0; i < origLoop.getNumRegionIterArgs(); ++i) {
+      mapping.map(
+          origLoop.getRegionIterArgs()[i],
+          fusedLoop.getRegionIterArgs()[loopAndInitArgsBegin.second + i]);
+    }
+
+    for (Operation &op : origLoop.getBody()->getOperations()) {
+      rewriter.clone(op, mapping);
+    }
+
+    if (origLoop.getNumResults() > 0) {
+      scf::YieldOp yieldFromOrigLoop =
+          cast<scf::YieldOp>(fusedLoop.getBody()->getTerminator());
+      yieldResults.append(yieldFromOrigLoop.getOperands().begin(),
+                          yieldFromOrigLoop.getOperands().end());
+      rewriter.eraseOp(yieldFromOrigLoop);
+    }
+  }
+
+  // Construct combined YieldOp
+  rewriter.setInsertionPointToEnd(fusedLoop.getBody());
+  rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+  // Replace all uses of the old loops with the fused loop.
+  unsigned numSourceOuts = source.getNumResults();
+  rewriter.replaceAllUsesWith(source.getResults(),
+                              fusedLoop.getResults().slice(0, numSourceOuts));
----------------
rolfmorel wrote:

Done.

https://github.com/llvm/llvm-project/pull/81495


More information about the Mlir-commits mailing list