[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #94190)
llvmlistbot at llvm.org
Mon Sep 9 18:51:47 PDT 2024
================
@@ -1654,52 +1729,99 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
if (isInsertSliceOp) {
auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
initSize = forOp.getInits().size();
} else {
auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
initSize = forallOp.getOutputs().size();
rank = forallOp.getRank();
}
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ // There are two possible cases regarding `oldLoopOp` here:
+ // 1. single `scf.forall` or `scf.for`.
+ // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
+ // top-level loop is the outer-most one of these nested loops.
+ Operation *oldTopLevelLoop = oldLoopOp;
+ SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
+ if (isInsertSliceOp) {
+ oldNestedForOps =
+ getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
+ isForOpYieldResultOfInnerLoop);
+ oldTopLevelLoop = oldNestedForOps.front();
+ }
+
+ if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp))) {
return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
+ oldTopLevelLoop,
+ "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
}
OpBuilder::InsertionGuard g(rewriter);
// 2. Check consumer is not using scf loop's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp)
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer op is not DPS operation");
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
+ SmallVector<Value> newInitAppend = dpsInits;
Location loc = oldLoopOp->getLoc();
// 3. Create new scf loop op.
rewriter.setInsertionPoint(consumerOp);
+
+ // 3.a Create new outer scf loops with new Inits only if nested `scf.for`
+ // case was found.
+ bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
----------------
MaheshRavishankar wrote:
OK, sorry, that was my bad. I see that the existing solution also creates new loops. If you want to add this, then I/you can clean it up after the fact.
Basically, I think this would be much easier if we just "moved" the operation before the consumer op and then used `addInitOperandsToLoopNest`. But that can be cleaned up after the fact as well. I am modifying code in this area anyway, so I am fine with taking this up.
https://github.com/llvm/llvm-project/pull/94190