[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
Abhishek Varma
llvmlistbot at llvm.org
Mon Jun 3 01:57:00 PDT 2024
================
@@ -1316,187 +1501,288 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- Operation *oldLoopOp = nullptr;
- SmallVector<Value> newOuts;
- Block *oldLoopBody = nullptr;
- unsigned initSize = 0;
- unsigned rank = 1;
- if (isInsertSliceOp) {
- auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
- oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
- initSize = forOp.getInits().size();
- } else {
- auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
- oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
- initSize = forallOp.getOutputs().size();
- rank = forallOp.getRank();
- }
-
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
- return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
- }
-
- OpBuilder::InsertionGuard g(rewriter);
-
- // 2. Check consumer is not using scf loop's output as init.
+ // 3. Check consumer is not using outerMostLoop's output as init.
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
+ ValueRange newInitAppend = dpsInits;
- Location loc = oldLoopOp->getLoc();
+ // 4. reconstruct nested loop from outer to inner.
+ SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
+ (*resultAndSliceOpsPair).second;
+ SmallVector<LoopLikeOpInterface> newOuterLoops;
+ SmallVector<SmallVector<OpFoldResult>> allResultOffsets, allResultSizes;
+ // extract slice from newInits of outer-most scf.forall
+ SmallVector<tensor::ExtractSliceOp> newExtractOps;
- // 3. Create new scf loop op.
- rewriter.setInsertionPoint(consumerOp);
- Operation *newLoopOp = nullptr;
+ Block *oldLoopBody = nullptr;
Block *newLoopBody = nullptr;
+ SmallVector<Value> newOuts;
+
+ OpBuilder::InsertionGuard g(rewriter);
+ // set insertPoint right before consumerOp
+ rewriter.setInsertionPoint(consumerOp);
+
+ for (auto [index, loop] :
+ llvm::enumerate(MutableArrayRef(outerLoops).drop_back())) {
+ if (index > 0)
+ rewriter.setInsertionPoint(loop);
+
+ LoopLikeOpInterface newLoopOp;
+ // Create a new loop with the new init values for this loop.
+ if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+ newOuts = llvm::to_vector(forOp.getInits());
+ newOuts.append(newInitAppend.begin(), newInitAppend.end());
+ auto newLoop = rewriter.create<scf::ForOp>(
+ forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ forOp.getStep(), newOuts);
+ newLoopOp = newLoop;
+ oldLoopBody = forOp.getBody();
+ newLoopBody = newLoop.getBody();
+ newInitAppend =
+ newLoopBody->getArguments().take_back(newInitAppend.size());
+ } else if (auto forallOp = dyn_cast<scf::ForallOp>(loop.getOperation())) {
+ newOuts = llvm::to_vector(forallOp.getOutputs());
+ newOuts.append(newInitAppend.begin(), newInitAppend.end());
+ auto newLoop = rewriter.create<scf::ForallOp>(
+ forallOp.getLoc(), forallOp.getMixedLowerBound(),
+ forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+ forallOp.getMapping());
+ rewriter.eraseOp(newLoop.getTerminator());
+ newLoopOp = newLoop;
+ oldLoopBody = forallOp.getBody();
+ newLoopBody = newLoop.getBody();
+
+ // create extractSliceOp for newInits
+ assert(index == 0 && "Currently Only outerMostLoop assumed ForallOp");
+ auto outerMostCandidate = candidateSliceOpList.back();
+ assert(isa<tensor::ParallelInsertSliceOp>(outerMostCandidate));
+ // set InsertPoint before next inner loop
+ auto nextLoop = outerLoops[index + 1];
+ rewriter.setInsertionPoint(nextLoop);
+ if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+ rewriter, cast<TilingInterface>(consumerOp), operandNumber,
+ outerMostCandidate, allResultOffsets, allResultSizes))) {
+ return failure();
+ }
+ fixSharedOutSCFForall(rewriter, newLoop, nextLoop, allResultOffsets,
+ allResultSizes, newInitAppend.size(),
+ newExtractOps);
+ newInitAppend = llvm::map_to_vector(
+ newExtractOps,
+ [](tensor::ExtractSliceOp op) -> Value { return op.getResult(); });
+ }
+ rewriter.mergeBlocks(
+ oldLoopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(oldLoopBody->getNumArguments()));
+ rewriter.replaceOp(
+ loop, newLoopOp->getResults().take_front(loop->getNumResults()));
+ newOuterLoops.push_back(newLoopOp);
+ }
+
+ // 5.a reconstruct inner-most loop.
+ LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
+ Location loc = oldInnerMostLoop->getLoc();
+ if (outerLoops.size() > 1)
+ rewriter.setInsertionPoint(oldInnerMostLoop);
+
if (isInsertSliceOp) {
- auto forOp = cast<scf::ForOp>(oldLoopOp);
+ auto forOp = cast<scf::ForOp>(oldInnerMostLoop.getOperation());
+ newOuts = llvm::to_vector(forOp.getInits());
+ newOuts.append(newInitAppend.begin(), newInitAppend.end());
+ oldLoopBody = forOp.getBody();
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(), newOuts);
- newLoopOp = newForOp;
+ newInnerMostLoop = newForOp;
newLoopBody = newForOp.getBody();
} else {
- auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+ auto forallOp = cast<scf::ForallOp>(oldInnerMostLoop.getOperation());
+ newOuts = llvm::to_vector(forallOp.getOutputs());
+ newOuts.append(newInitAppend.begin(), newInitAppend.end());
+ oldLoopBody = forallOp.getBody();
auto newForallOp = rewriter.create<scf::ForallOp>(
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
forallOp.getMixedStep(), newOuts, forallOp.getMapping());
- newLoopOp = newForallOp;
+ newInnerMostLoop = newForallOp;
rewriter.eraseOp(newForallOp.getTerminator());
newLoopBody = newForallOp.getBody();
}
- // 4. Move the loop body to the new op.
+ // 5.b Move the loop body to the new op.
unsigned oldNumArguments = oldLoopBody->getNumArguments();
rewriter.mergeBlocks(oldLoopBody, newLoopBody,
newLoopBody->getArguments().take_front(oldNumArguments));
+ // 5.c replace the result of old oldInnerMostLoop with newInnerMostLoop's
+ // results.
+ rewriter.replaceOp(oldInnerMostLoop,
+ newInnerMostLoop->getResults().take_front(
+ oldInnerMostLoop->getNumResults()));
- // 5. Set insertion point before terminator op of the loop and create a new
+ // 6. Set insertion point before terminator op of the loop and create a new
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
tensor::InsertSliceOp clonedInsertSliceOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
- auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
rewriter.setInsertionPoint(newForallOp.getTerminator());
- clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else {
rewriter.setInsertionPoint(candidateSliceOp);
- clonedInsertSliceOp =
- cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
}
-
- // 6.a. Clone consumer op.
- auto newForOpBlockArgsForConsumerDest =
- newLoopBody->getArguments().drop_front(oldNumArguments);
+ FailureOr<SmallVector<OpFoldResult>> realOffsets =
+ computeRealOffsetsCoordinatedRootSliceOp(rewriter, loc, ossSliceOp,
+ candidateSliceOpList);
+ if (failed(realOffsets))
+ return failure();
+ // create dummy insertSliceOp to align with the requirement of current
+ // Tiling interface and fix potential semantic mismatch with later
+ // extractSliceOp generated by `getTiledImplementation`.
----------------
Abhishek-Varma wrote:
```suggestion
// Step 7: Create dummy insertSliceOp to align with the requirement of
// current tiling interface and fix potential semantic mismatch with the
// extractSliceOp generated by `getTiledImplementation`.
```
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits
mailing list