[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 16 09:32:10 PDT 2024
================
@@ -1418,52 +1560,93 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
if (isInsertSliceOp) {
auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
initSize = forOp.getInits().size();
} else {
auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
initSize = forallOp.getOutputs().size();
rank = forallOp.getRank();
}
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ Operation *oldTopLevelLoop = oldLoopOp;
+ SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
+ if (isInsertSliceOp) {
+ oldNestedForOps =
+ getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
+ isForOpYieldResultOfInnerLoop);
+ oldTopLevelLoop = oldNestedForOps.front();
+ }
+ // 2.a Check assumption for loop and find suitable insertPoint that loop
+ // structure would be cloned right before.
+ Operation *insertPointBefore = nullptr;
+ if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp,
+ &insertPointBefore))) {
return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
+ oldTopLevelLoop, "containing loop op does not satisfy the assumption "
+ "and no suitable insertPoint is found");
}
OpBuilder::InsertionGuard g(rewriter);
- // 2. Check consumer is not using scf loop's output as init.
+ // 2.b Check consumer is not using scf loop's output as init.
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
+ SmallVector<Value> newInitAppend = dpsInits;
Location loc = oldLoopOp->getLoc();
// 3. Create new scf loop op.
- rewriter.setInsertionPoint(consumerOp);
+ rewriter.setInsertionPoint(insertPointBefore);
+
+ // 3.a Create new outer scf loops if necessary
----------------
MaheshRavishankar wrote:
I am missing why we need to create new scf.fors...
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits mailing list