[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 16 09:32:10 PDT 2024


================
@@ -1577,28 +1760,159 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         newForallOp.getBody()->getArguments().drop_front(rank + initSize));
   }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
+  // 12.  Restore outer loops from inner to outer
+  if (isNestedForOps) {
+    newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
+    for (auto [outerLoop, innerLoop] :
+         llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
+                         MutableArrayRef(newNestedForOps).drop_front())) {
+      auto forOp = cast<scf::ForOp>(outerLoop);
+      auto outerLoopYield =
+          cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+      SmallVector<Value> newYields =
+          llvm::to_vector(outerLoopYield.getOperands());
+      ValueRange additionalYields =
+          innerLoop->getResults().take_back(newInitAppend.size());
+      newYields.append(additionalYields.begin(), additionalYields.end());
+      rewriter.setInsertionPoint(outerLoopYield);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+    }
+  }
+
+  // 13. Replace the result of scf loop and consumer op with new loop's results.
   for (auto &&[oldResult, newResult] :
        llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
+  Operation *newTopLevelLoop =
+      isNestedForOps ? newNestedForOps.front() : newLoopOp;
   for (auto &&[oldResult, newResult] :
        llvm::zip(consumerOp->getResults(),
-                 newLoopOp->getResults().drop_front(initSize))) {
+                 newTopLevelLoop->getResults().drop_front(initSize))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the old scf loop and the cloned consumer op.
+  // 14. Need to erase the old scf loop and the cloned consumer op.
   rewriter.eraseOp(oldLoopOp);
   rewriter.eraseOp(clonedConsumerOp);
 
+  // 15. Erase the cloned insertSliceOp and the now-unused extractSliceOp to
+  // avoid complex dominance analysis.
+  assert(clonedInsertSliceOp->hasOneUse());
+  auto unUsedExtractOp =
+      cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers().begin()));
+  rewriter.eraseOp(unUsedExtractOp);
+  rewriter.eraseOp(clonedInsertSliceOp);
+
   return scf::SCFFuseConsumerOfSliceResult{
       consumerOpOperand,
       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
       tileAndFuseResult->tiledOps};
 }
 
+/// Get the real consumers from candidate InsertSliceOp. E.g
+///
+/// ```
+/// %1 = scf.for
+///  %2 = scf.for
+///   %3 = scf.for
+///      ...
+///      %4 = insert
+///      yield %4
+///   %5 = insert %3
+///   yield %5
+///  yield %2
+/// %6 = consumerOp ins(%1)
+/// ```
+///
+/// @param candidateSliceOp: %4 = insert
+/// @param forwardSlice: in-out parameter populated by forward insertSliceOps
+/// @return OpOperand consumers: %6 = consumerOp ins(%1)
+static FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp(
----------------
MaheshRavishankar wrote:

This method is too much for me to follow and seems like a generalization past what we should do in one step. Can we defer the complication this is adding to a follow-up?

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list