[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
llvmlistbot at llvm.org
Sun Sep 8 19:53:03 PDT 2024
================
@@ -1654,52 +1729,99 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
if (isInsertSliceOp) {
auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
initSize = forOp.getInits().size();
} else {
auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
initSize = forallOp.getOutputs().size();
rank = forallOp.getRank();
}
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ // There are two possible cases regarding `oldLoopOp` here:
+ // 1. single `scf.forall` or `scf.for`.
+ // 2. inner-most `scf.for` inside a nested `scf.for` structure, where the
+ // top-level loop is the outer-most one of these nested loops.
+ Operation *oldTopLevelLoop = oldLoopOp;
+ SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
+ if (isInsertSliceOp) {
+ oldNestedForOps =
+ getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
+ isForOpYieldResultOfInnerLoop);
+ oldTopLevelLoop = oldNestedForOps.front();
+ }
+
+ if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp))) {
return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
+ oldTopLevelLoop,
+ "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
}
OpBuilder::InsertionGuard g(rewriter);
// 2. Check consumer is not using scf loop's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp)
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer op is not DPS operation");
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
+ SmallVector<Value> newInitAppend = dpsInits;
Location loc = oldLoopOp->getLoc();
// 3. Create new scf loop op.
rewriter.setInsertionPoint(consumerOp);
+
+ // 3.a Create new outer scf loops with new Inits only if nested `scf.for`
+ // case was found.
+ bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
----------------
Yun-Fly wrote:
> But an even easier solution here may be to move the outermost loop just before the consumer?
Isn't that exactly what the current solution does? Or could you detail how it differs from what you expected?
> I think it might be easier to merge this with the inner loop handling?
The original intention here is to NOT mix (or merge) the new feature with the existing code, so as not to add confusion during review. If you really prefer merging the new and the current handling, please let me know :)
https://github.com/llvm/llvm-project/pull/94190
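To make case 2 of the new code comment in the diff concrete, here is a minimal, self-contained sketch of the outward walk that `getOuterNestLoopsWhile(..., isForOpYieldResultOfInnerLoop)` appears to perform. `ToyLoop`, `yieldsInnerLoopResult` and `collectOuterNestLoopsWhile` are hypothetical names on a simplified model, not MLIR APIs. The real predicate inspects `scf.yield` operands, and the exact semantics shown here are an assumption. The one property taken from the diff is that the returned vector is ordered outermost-first, so `front()` is the top-level loop.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical, simplified model of an scf.for op (NOT the MLIR API):
// each loop knows its enclosing loop and which child loop's results its
// terminator (scf.yield) forwards, if any.
struct ToyLoop {
  const ToyLoop *parent = nullptr;          // enclosing loop, or nullptr
  const ToyLoop *yieldsResultsOf = nullptr; // child loop whose results are yielded
};

// Assumed predicate in the spirit of `isForOpYieldResultOfInnerLoop`:
// the outer loop simply forwards the results of the inner loop.
bool yieldsInnerLoopResult(const ToyLoop *outer, const ToyLoop *inner) {
  return outer->yieldsResultsOf == inner;
}

// Walk outward from `innermost` while `pred(outer, inner)` holds, then
// reverse so the result is ordered outermost-first (front() is the
// top-level loop, back() is the innermost loop).
std::vector<const ToyLoop *> collectOuterNestLoopsWhile(
    const ToyLoop *innermost,
    const std::function<bool(const ToyLoop *, const ToyLoop *)> &pred) {
  assert(innermost && "expected a loop to start from");
  std::vector<const ToyLoop *> loops{innermost};
  for (const ToyLoop *outer = innermost->parent;
       outer && pred(outer, loops.back()); outer = outer->parent)
    loops.push_back(outer);
  std::reverse(loops.begin(), loops.end());
  return loops;
}
```

With a three-level nest in which each outer loop only forwards its child's results, the walk returns all three loops with the outermost one first; it stops at the first enclosing loop that does not forward the nest's results. A single loop (case 1) yields a vector of size one, which is why the diff only takes the nested path when `oldNestedForOps.size() > 1`.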