[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #94190)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 10 18:16:00 PDT 2024
================
@@ -1646,81 +1641,58 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- Operation *oldLoopOp = nullptr;
- SmallVector<Value> newOuts;
- Block *oldLoopBody = nullptr;
- unsigned initSize = 0;
- unsigned rank = 1;
+ // There are two possible cases regarding `oldLoopOp` here:
+ // 1. single `scf.forall` or `scf.for`.
+ // 2. inner-most `scf.for` inside a nested `scf.for` structure, where the
+ // top-level loop is the outer-most one of these nested loops.
+ LoopLikeOpInterface innerMostLoop =
+ candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
+ SmallVector<LoopLikeOpInterface> nestedLoops;
if (isInsertSliceOp) {
- auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
- oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
- initSize = forOp.getInits().size();
+ nestedLoops = getPerfectlyOuterNestedLoops(innerMostLoop);
} else {
- auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
- oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
- initSize = forallOp.getOutputs().size();
- rank = forallOp.getRank();
+ nestedLoops = {innerMostLoop};
}
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+
+ if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
+ outerMostLoop,
+ "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
}
OpBuilder::InsertionGuard g(rewriter);
// 2. Check consumer is not using scf loop's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp)
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer op is not DPS operation");
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
-
- Location loc = oldLoopOp->getLoc();
+ SmallVector<Value> newInits = dpsInits;
- // 3. Create new scf loop op.
- rewriter.setInsertionPoint(consumerOp);
- Operation *newLoopOp = nullptr;
- Block *newLoopBody = nullptr;
- if (isInsertSliceOp) {
- auto forOp = cast<scf::ForOp>(oldLoopOp);
- auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
- forOp.getUpperBound(),
- forOp.getStep(), newOuts);
- newLoopOp = newForOp;
- newLoopBody = newForOp.getBody();
- } else {
- auto forallOp = cast<scf::ForallOp>(oldLoopOp);
- auto newForallOp = rewriter.create<scf::ForallOp>(
- loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep(), newOuts, forallOp.getMapping());
- newLoopOp = newForallOp;
- rewriter.eraseOp(newForallOp.getTerminator());
- newLoopBody = newForallOp.getBody();
- }
+ Location loc = outerMostLoop->getLoc();
- // 4. Move the loop body to the new op.
- unsigned oldNumArguments = oldLoopBody->getNumArguments();
- rewriter.mergeBlocks(oldLoopBody, newLoopBody,
- newLoopBody->getArguments().take_front(oldNumArguments));
+ // 3. Move the whole loop structure right before the consumer op; dominance
+ // should already be ensured by `checkAssumptionForLoop`.
+ outerMostLoop->moveBefore(consumerOp);
----------------
Yun-Fly wrote:
Thanks for this reminder! Changed.
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits
mailing list