[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 16 09:32:10 PDT 2024
================
@@ -1577,28 +1760,159 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
newForallOp.getBody()->getArguments().drop_front(rank + initSize));
}
- // 12. Replace the result of scf loop and consumer op with new loop's results.
+ // 12. Restore outer loops from inner to outer
+ if (isNestedForOps) {
+ newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
+ MutableArrayRef(newNestedForOps).drop_front())) {
+ auto forOp = cast<scf::ForOp>(outerLoop);
+ auto outerLoopYield =
+ cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ SmallVector<Value> newYields =
+ llvm::to_vector(outerLoopYield.getOperands());
+ ValueRange additionalYields =
+ innerLoop->getResults().take_back(newInitAppend.size());
+ newYields.append(additionalYields.begin(), additionalYields.end());
+ rewriter.setInsertionPoint(outerLoopYield);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+ }
+ }
+
+ // 13. Replace the result of scf loop and consumer op with new loop's results.
for (auto &&[oldResult, newResult] :
llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
+ Operation *newTopLevelLoop =
+ isNestedForOps ? newNestedForOps.front() : newLoopOp;
for (auto &&[oldResult, newResult] :
llvm::zip(consumerOp->getResults(),
- newLoopOp->getResults().drop_front(initSize))) {
+ newTopLevelLoop->getResults().drop_front(initSize))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
- // 13. Need to erase the old scf loop and the cloned consumer op.
+ // 14. Need to erase the old scf loop and the cloned consumer op.
rewriter.eraseOp(oldLoopOp);
rewriter.eraseOp(clonedConsumerOp);
+ // 15. Need to erase the cloned insertSliceOp and the unused extractSliceOp to
+ // avoid complex dominance analysis.
+ assert(clonedInsertSliceOp->hasOneUse());
+ auto unUsedExtractOp =
+ cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers().begin()));
+ rewriter.eraseOp(unUsedExtractOp);
+ rewriter.eraseOp(clonedInsertSliceOp);
+
return scf::SCFFuseConsumerOfSliceResult{
consumerOpOperand,
&(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
tileAndFuseResult->tiledOps};
}
+/// Get the real consumers from candidate InsertSliceOp. E.g
+///
+/// ```
+/// %1 = scf.for
+/// %2 = scf.for
+/// %3 = scf.for
+/// ...
+/// %4 = insert
+/// yield %4
+/// %5 = insert %3
+/// yield %5
+/// yield %2
+/// %6 = consumerOp ins(%1)
+/// ```
+///
+/// @param candidateSliceOp: %4 = insert
+/// @param forwardSlice: in-out parameter populated by forward insertSliceOps
+/// @return OpOperand consumers: %6 = consumerOp ins(%1)
+static FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp(
----------------
MaheshRavishankar wrote:
This method is too much for me to follow and seems like a generalization past what we should do in one step. Can we defer the complication this is adding to a follow-up?
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits
mailing list