[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #108318)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 11 19:42:52 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: None (Yun-Fly)
<details>
<summary>Changes</summary>
This is a mirror PR of #<!-- -->94190 with a tiny build fix.
Sorry for the inconvenience.
---
Patch is 28.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108318.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+174-174)
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+70-7)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..f4cf92201068ae 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1481,6 +1481,50 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
return &operand;
}
+/// Find the perfectly nested loops outside of given loop(included) sorted from
+/// outer to inner.
+///
+/// E.g.
+///
+/// ```
+/// %0 = scf.for()
+/// %1 = scf.for()
+/// %2 = scf.for()
+/// %3 = ...
+/// yield %3
+/// yield %2
+/// yield %1
+/// ```
+///
+/// This function will return three perfectly nested loops: %0 + %1 + %2, when
+/// target inner loop is %2.
+static SmallVector<scf::ForOp>
+getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
+ SmallVector<scf::ForOp> nestLoops = {loop};
+ auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
+
+ // Check if it is the ForOp that yield the result of inner loop.
+ auto isForOpYieldResultOfInnerLoop =
+ [](scf::ForOp outerLoop) -> LogicalResult {
+ Block *body = outerLoop.getBody();
+ if (!llvm::hasSingleElement(body->without_terminator()))
+ return failure();
+ auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
+ auto innerForOp = dyn_cast<scf::ForOp>(body->front());
+ if (!innerForOp)
+ return failure();
+ // All of innerForOp results should be yielded.
+ return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
+ };
+
+ while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
+ nestLoops.push_back(outerLoop);
+ outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
+ }
+ // sorted from outer to inner
+ return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions :
/// 1. tensor.insert_slice has scf.yield as its only user.
@@ -1498,9 +1542,10 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
auto forOp = dyn_cast<scf::ForOp>(containingOp);
if (!forOp)
return failure();
- Value resultingValue = forOp->getResult(resultNumber);
+ scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
+ Value resultingValue = topLevelForOp->getResult(resultNumber);
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
+ return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
}
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1563,59 +1608,6 @@ static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
}
}
-/// After fusing consumer into scf.for we want to modify the scf.yield operation
-/// to reflect the same by returning the values yielded by the tiled consumer.
-static void
-fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
- TilingResult &tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
- ArrayRef<BlockArgument> bbArgs) {
- scf::YieldOp oldTerminatorOp =
- cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
- unsigned totalOldResults = oldTerminatorOp->getNumResults();
- unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
- SmallVector<Value> newYieldOperands;
- newYieldOperands.reserve(totalOldResults + totalTiledResults);
- for (auto oldResult : oldTerminatorOp.getResults()) {
- newYieldOperands.push_back(oldResult);
- }
- rewriter.setInsertionPointAfter(oldTerminatorOp);
- Location loc = newForOp.getLoc();
- for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
- resultOffsets, resultSizes)) {
- SmallVector<OpFoldResult> strides(resultOffset.size(),
- rewriter.getIndexAttr(1));
- Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, tiledResult, bbArg, resultOffset, resultSize, strides);
- newYieldOperands.push_back(newInsertSliceOp);
- }
- rewriter.create<scf::YieldOp>(loc, newYieldOperands);
- rewriter.eraseOp(oldTerminatorOp);
-}
-
-/// After fusing consumer into scf.forall we want to yield each of the resulting
-/// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
- SmallVector<Value> tiledResults,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
- ArrayRef<BlockArgument> bbArgs) {
- scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
- rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
- Location firstYieldOpLoc =
- (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
- for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
- SmallVector<OpFoldResult> strides(resultOffset.size(),
- rewriter.getIndexAttr(1));
- rewriter.create<tensor::ParallelInsertSliceOp>(
- firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
- }
-}
-
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1646,81 +1638,63 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- Operation *oldLoopOp = nullptr;
- SmallVector<Value> newOuts;
- Block *oldLoopBody = nullptr;
- unsigned initSize = 0;
- unsigned rank = 1;
+ // There are two possible cases regarding `oldLoopOp` here:
+ // 1. single `scf.forall` or `scf.for`.
+ // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
+ // top-level loop is the outer-most one of these nested loops.
+ LoopLikeOpInterface innerMostLoop =
+ candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
+ SmallVector<LoopLikeOpInterface> nestedLoops;
if (isInsertSliceOp) {
- auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
- oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
- initSize = forOp.getInits().size();
+ nestedLoops = llvm::map_to_vector(
+ getPerfectlyNestedLoopsOutsideOf(
+ cast<scf::ForOp>(innerMostLoop.getOperation())),
+ [](scf::ForOp forOp) {
+ return cast<LoopLikeOpInterface>(forOp.getOperation());
+ });
} else {
- auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
- oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
- initSize = forallOp.getOutputs().size();
- rank = forallOp.getRank();
+ nestedLoops = {innerMostLoop};
}
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+
+ if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
+ outerMostLoop,
+ "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
}
OpBuilder::InsertionGuard g(rewriter);
// 2. Check consumer is not using scf loop's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp)
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer op is not DPS operation");
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
-
- Location loc = oldLoopOp->getLoc();
+ SmallVector<Value> newInits = dpsInits;
- // 3. Create new scf loop op.
- rewriter.setInsertionPoint(consumerOp);
- Operation *newLoopOp = nullptr;
- Block *newLoopBody = nullptr;
- if (isInsertSliceOp) {
- auto forOp = cast<scf::ForOp>(oldLoopOp);
- auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
- forOp.getUpperBound(),
- forOp.getStep(), newOuts);
- newLoopOp = newForOp;
- newLoopBody = newForOp.getBody();
- } else {
- auto forallOp = cast<scf::ForallOp>(oldLoopOp);
- auto newForallOp = rewriter.create<scf::ForallOp>(
- loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep(), newOuts, forallOp.getMapping());
- newLoopOp = newForallOp;
- rewriter.eraseOp(newForallOp.getTerminator());
- newLoopBody = newForallOp.getBody();
- }
+ Location loc = outerMostLoop->getLoc();
- // 4. Move the loop body to the new op.
- unsigned oldNumArguments = oldLoopBody->getNumArguments();
- rewriter.mergeBlocks(oldLoopBody, newLoopBody,
- newLoopBody->getArguments().take_front(oldNumArguments));
+ // 3. Move the whole loop structure right before consumer Op, the dominance
+ // should be already ensured by `checkAssumptionForLoop`.
+ rewriter.moveOpBefore(outerMostLoop, consumerOp);
- // 5. Set insertion point before terminator op of the loop and create a new
+ // 4. Set insertion point before terminator op of the loop and create a new
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
tensor::InsertSliceOp clonedInsertSliceOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
- auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
rewriter.setInsertionPoint(newForallOp.getTerminator());
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
@@ -1731,20 +1705,17 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
}
- // 6.a. Clone consumer op.
- auto newForOpBlockArgsForConsumerDest =
- newLoopBody->getArguments().drop_front(oldNumArguments);
- auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
- rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+ // 5.a. Clone consumer op.
+ auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
- // 6.b. Replace all uses of the loop result with the result of the cloned
+ // 5.b. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
operandToReplace.set(clonedInsertSliceOp.getResult());
});
- // 7 - Perform tiling of the cloned consumer and replace the operand at
+ // 6. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
auto ossSliceOp =
cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
@@ -1754,79 +1725,108 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
if (failed(tileAndFuseResult)) {
return failure();
}
- rewriter.replaceAllUsesWith(
- tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
- clonedInsertSliceOp.getSource());
-
- // 8 - Extract offset/sizes/strides required to create the
- // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
- SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
- // 9. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
+ auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
+ rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
+ clonedInsertSliceOp.getSource());
- // 10. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp, "can't get iter domain position from input position");
- }
+ // 7. Reconstruct [nested] loop with new inits.
+ YieldTiledValuesFn newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+ OpBuilder::InsertionGuard g(innerRewriter);
+ // 8. Set inner insertPoint right before tiled consumer op.
+ innerRewriter.setInsertionPoint(tiledConsumerOp);
- // 11. Try to fetch the offset and size for all results of the cloned
- // consumer. This would then be used to form the corresponding
- // tensor.insert_slice/parallel_insert_slice later.
- unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
- for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
- if (failed(clonedConsumerOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
+ SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+ // 9. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
return rewriter.notifyMatchFailure(
- clonedConsumerOp,
- "can't get result domain position from iter domain position");
+ candidateSliceOp, "containingOp's result yield with stride");
}
- }
- auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
- auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
- if (isInsertSliceOp) {
- auto newForOp = cast<scf::ForOp>(newLoopOp);
- fixTerminatorSCFYield(
- rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
- newForOp.getBody()->getArguments().drop_front(1 + initSize));
- } else {
- auto newForallOp = cast<scf::ForallOp>(newLoopOp);
- fixTerminatorSCFInParallel(
- rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
- arrayRefOffsets, arrayRefSizes,
- newForallOp.getBody()->getArguments().drop_front(rank + initSize));
- }
+ // 10. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ tiledConsumerOp,
+ "can't get iter domain position from input position");
+ }
- // 12. Replace the result of scf loop and consumer op with new loop's results.
- for (auto &&[oldResult, newResult] :
- llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
- rewriter.replaceAllUsesWith(oldResult, newResult);
+ // 11. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.insert_slice/parallel_insert_slice later.
+ unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(
+ totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
+ if (failed(tiledConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ tiledConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+
+ // 12. Create `extract_slice` for `iter_args` for DPS operation if
+ // necessary.
+ if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
+ tiledConsumerOp.getOperation())) {
+ rewriter.setInsertionPoint(tiledDestStyleOp);
+ for (const auto &&[index, newRegionArg] :
+ llvm::enumerate(newRegionIterArgs)) {
+ auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+ loc, newRegionArg, resultOffsets[index], resultSizes[index],
+ SmallVector<OpFoldResult>(resultOffsets[index].size(),
+ rewriter.getIndexAttr(1)));
+ // Make C++ 17 happy, otherwise it will throw error `captured structured
+ // bindings are a C++20 extension`.
+ auto dstNumber = index;
+ rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
+ tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
+ });
+ }
+ }
+
+ // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
+ // caller.
+ Block *block = rewriter.getInsertionPoint()->getBlock();
+ rewriter.setInsertionPoint(block->getTerminator());
+ for (const auto &&[index, result] :
+ llvm::enumerate(tiledConsumerOp->getResults())) {
+ tiledResult.push_back(result);
+ tiledOffset.emplace_back(resultOffsets[index]);
+ tiledSizes.emplace_back(resultSizes[index]);
+ }
+ return success();
+ };
+ // 14. Add new inits to [nested] loops.
+ if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
+ newYieldValuesFn))) {
+ return rewriter.notifyMatchFailure(tiledConsumerOp,
+ "unable to add new inits to nest loop");
}
- for (auto &&[oldResult, newResult] :
- llvm::zip(consumerOp->getResults(),
- newLoopOp->getResults().drop_front(initSize))) {
+ // 15. Replace the result of scf loop and consumer op with new loop's results.
+
+ for (auto &&[oldResult, newResult] : llvm::zip(
+ consumerOp->getResults(),
+ nestedLoops.front()->getResults().take_back(newInits.size()))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
- // 13. Need to erase the old scf loop and the cloned consumer op.
- rewriter.eraseOp(oldLoopOp);
+ // 16. Need to erase the old scf loop and the cloned consumer op.
rewriter.eraseOp(clonedConsumerOp);
return scf::SCF...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/108318
More information about the Mlir-commits
mailing list