[Mlir-commits] [mlir] [mlir][TilingInterface] Make `tileAndFuseConsumerOfSlice` take surrounding loops as an argument. (PR #132082)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 19 11:51:13 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (MaheshRavishankar)
<details>
<summary>Changes</summary>
This brings the consumer fusion method in sync with the corresponding producer fusion method `tileAndFuseProducerOfSlice`. Not taking the surrounding loops as input required a complicated analysis to retrieve them, and that analysis was very fragile. Just like the producer fusion method, the loops are now taken as an argument; typically these are the loops created by the tiling methods.
Some utilities are added to check that the loops passed in are perfectly nested (in the case of an `scf.for` loop nest).
This is change 1 of N to simplify the implementation of tile and fuse consumers.
---
Full diff: https://github.com/llvm/llvm-project/pull/132082.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+2-1)
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+105-47)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index d2cddfe00ac78..33a43ce2ee7bb 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult {
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index af87fb7a79d04..4fd10b0e30ab0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1890,25 +1890,81 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
return {nestLoops.rbegin(), nestLoops.rend()};
}
+/// Check that the loop is perfectly nested.
+static bool
+isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected empty loop nest");
+ if (loops.size() == 1) {
+ return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
+ }
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
+ if (!outerFor || !innerFor) {
+ return false;
+ }
+ auto outerBBArgs = outerFor.getRegionIterArgs();
+ auto innerIterArgs = innerFor.getInitArgs();
+ if (outerBBArgs.size() != innerIterArgs.size()) {
+ return false;
+ }
+
+ for (auto [outerBBArg, innerIterArg] :
+ llvm::zip(outerBBArgs, innerIterArgs)) {
+ if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
+ innerIterArg != outerBBArg) {
+ return false;
+ }
+ }
+
+ auto outerYields =
+ cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
+ auto innerResults = innerFor.getResults();
+ if (outerYields.size() != innerResults.size()) {
+ return false;
+ }
+ for (auto [outerYield, innerResult] :
+ llvm::zip(outerYields, innerResults)) {
+ if (!llvm::hasSingleElement(innerResult.getUses()) ||
+ outerYield != innerResult) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions :
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
- tensor::InsertSliceOp candidateSliceOp) {
+ tensor::InsertSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected loops to be empty");
+ // 1. Expect slice to be part of the body of the inner most loop.
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ if (containingOp != loops.back()) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp,
+ "expected slice to be within body of inner-most loop");
+ }
+
+ if (!isPerfectlyNestedForLoops(loops)) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected passed loops to be perfectly nested.");
+ }
+
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
return failure();
Value sliceResult = candidateSliceOp.getResult();
// Step 1. Fetch the corresponding output.
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
unsigned resultNumber = yieldOpOperand.getOperandNumber();
- // Step 2. Check containing op is scf.for.
- Operation *containingOp = candidateSliceOp->getParentOp();
- auto forOp = dyn_cast<scf::ForOp>(containingOp);
- if (!forOp)
- return failure();
- scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
+
+ scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
}
@@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
/// by a tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
- tensor::ParallelInsertSliceOp candidateSliceOp) {
- // Step 1. Fetch the corresponding output
+ tensor::ParallelInsertSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected loops to be empty");
+ // 1. Check that the surrounding loop is a single scf.forall loop.
+ if (loops.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected single surrounding scf.forall");
+ }
+ auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
+ if (!forallOp) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected single surrounding scf.forall");
+ }
+
+ // 2. Fetch the corresponding output
Value sliceDest = candidateSliceOp.getDest();
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
if (!iterArg)
return failure();
- Operation *containingOp = iterArg.getOwner()->getParentOp();
- if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
- return failure();
- // Step 2. Check that the containing op is scf.forall.
- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp)
+ if (iterArg.getOwner()->getParentOp() != forallOp)
return failure();
+
unsigned resultNumber =
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
.getResultNumber();
- return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
+ return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
}
/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ if (loops.empty()) {
+ return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
+ }
+
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(rewriter, insertSlice);
+ return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
} else if (auto parallelInsertSlice =
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
+ return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
} else {
return failure();
}
@@ -1954,18 +2024,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
- Operation *candidateSliceOp) {
+mlir::scf::tileAndFuseConsumerOfSlice(
+ RewriterBase &rewriter, Operation *candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ // Return if `loops` is empty, return an error for now. Caller is expected
+ // to handle this case.
+ if (loops.empty()) {
+ return candidateSliceOp->emitOpError(
+ "cannot call tile and fuse consumer with an empty loop nest");
+ }
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))
return failure();
- bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
-
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand =
- getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
+ getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
@@ -1981,25 +2056,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- // There are two possible cases regarding `oldLoopOp` here:
- // 1. single `scf.forall` or `scf.for`.
- // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
- // top-level loop is the outer-most one of these nested loops.
- LoopLikeOpInterface innerMostLoop =
- candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
- SmallVector<LoopLikeOpInterface> nestedLoops;
- if (isInsertSliceOp) {
- nestedLoops = llvm::map_to_vector(
- getPerfectlyNestedLoopsOutsideOf(
- cast<scf::ForOp>(innerMostLoop.getOperation())),
- [](scf::ForOp forOp) {
- return cast<LoopLikeOpInterface>(forOp.getOperation());
- });
- } else {
- nestedLoops = {innerMostLoop};
- }
-
- LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+ LoopLikeOpInterface outerMostLoop = loops.front();
+ LoopLikeOpInterface innerMostLoop = loops.back();
// Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
@@ -2165,7 +2223,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
return success();
};
// 14. Add new inits to [nested] loops.
- if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
+ if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
newYieldValuesFn))) {
return rewriter.notifyMatchFailure(tiledConsumerOp,
"unable to add new inits to nest loop");
@@ -2174,9 +2232,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 15. Replace the result of scf loop and consumer op with new loop's
// results.
- for (auto &&[oldResult, newResult] : llvm::zip(
- consumerOp->getResults(),
- nestedLoops.front()->getResults().take_back(newInits.size()))) {
+ for (auto &&[oldResult, newResult] :
+ llvm::zip(consumerOp->getResults(),
+ loops.front()->getResults().take_back(newInits.size()))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/132082
More information about the Mlir-commits
mailing list