[Mlir-commits] [mlir] Extend `TilingInterface` to allow more flexible tiling (PR #95422)
llvmlistbot at llvm.org
Thu Jun 13 08:36:53 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Srinath Avadhanula (srinathava)
<details>
<summary>Changes</summary>
Ref: [Discourse thread](https://discourse.llvm.org/t/extending-tileconsumerandfuseproducer-to-handle-more-patterns/79340/2)
Problem:
The current version of `transform.structured.fuse` relies on ops implementing the `TilingInterface`. An op that implements this interface returns a `TilingResult`, [defined](https://sourcegraph.robot.car/github.robot.car/cruise/mla-robocomp-llvm-project/-/blob/mlir/include/mlir/Interfaces/TilingInterface.h?L26) as:
```c++
/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
};
```
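As currently implemented in `tileConsumerAndFuseProducersUsingSCF`, only the _last_ operation in `tiledOps` is considered for further fusion: the `tensor.extract_slice` ops that feed its operands are the only ones used to seed the worklist of fusion candidates.
This breaks down when we implement the `TilingInterface` for the `tosa.concat` operation as follows (MLIR pseudo-code):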
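```mlir
// Pseudo-code: %size_t1 stands for the extent of %arg1 along the concat axis.
%in_first = arith.cmpi ult, %offset, %size_t1 : index
%slice = scf.if %in_first -> (tensor<...>) {
  %s1 = tensor.extract_slice %arg1 ...
  scf.yield %s1 : tensor<...>
} else {
  %s2 = tensor.extract_slice %arg2 ...
  scf.yield %s2 : tensor<...>
}
```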
Even if both `scf.yield` ops are returned in the `tiledOps` field, only the slice feeding the last one is considered for fusion with upstream producers, so the producer of the other concat input is never fused.
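To make the limitation concrete, here is a minimal, MLIR-free model of the pre-PR seeding step. `MockOp` and `seedCandidatesOld` are hypothetical stand-ins, not MLIR APIs: each mock op only records the names of the `tensor.extract_slice` values among its operands, and, mirroring `addCandidateSlices(tiledAndFusedOps.back(), candidates)`, only the last tiled op is inspected.

```c++
#include <deque>
#include <string>
#include <vector>

// Hypothetical stand-in for an Operation*: only the names of operands that
// are produced by tensor.extract_slice ops are tracked.
struct MockOp {
  std::string name;
  std::vector<std::string> sliceOperands;
};

// Models the pre-PR behavior: only the slice operands of the *last* tiled op
// seed the fusion worklist; slices feeding earlier tiled ops are ignored.
std::deque<std::string> seedCandidatesOld(const std::vector<MockOp> &tiledOps) {
  std::deque<std::string> candidates;
  if (tiledOps.empty())
    return candidates;
  for (const std::string &slice : tiledOps.back().sliceOperands)
    candidates.push_back(slice);
  return candidates;
}
```

With the two yields of the concat example modeled as `{{"yield_then", {"slice_arg1"}}, {"yield_else", {"slice_arg2"}}}`, the returned deque holds only `"slice_arg2"`, so the producer of `%arg1` is never visited.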
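This PR extends `TilingResult` with an `extractSliceOps` field (a `SmallVector<Operation *>`) that lists the `tensor.extract_slice` ops created while generating the tiled result, including slices embedded in regions owned by the tiled ops. A `TilingInterface` implementation can therefore directly return every slice op it created. The list is plumbed through `TilingResult` -> `SCFTilingResult` -> `SCFFuseProducerOfSliceResult` and is used to seed the worklist of `tensor.extract_slice` ops processed for fusion, both for the initial consumer and for every fused producer. The existing Linalg implementations (`LinalgOpTilingInterface` and `SoftmaxOp::getTiledImplementation`) now populate `extractSliceOps` from the slice-producing operands of the tiled op. Additionally, `yieldReplacementForFusedProducer` now re-targets the source of the existing init slice to the new region iter_arg instead of creating a fresh `tensor.extract_slice`.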
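A similarly hedged sketch of the worklist after this change, again using hypothetical stand-in types (`MockTilingResult`, `fuseWithWorklist`) rather than the real MLIR classes: the consumer's reported `extractSliceOps` seed a BFS deque, and every successfully fused producer contributes its own `extractSliceOps`, mirroring `addCandidateSlices(fusedResult->extractSliceOps, candidates)` in the diff. The map `fuseProducerOf` is an assumption that stands in for the actual call to `tileAndFuseProducerOfSlice`.

```c++
#include <deque>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for the extended TilingResult: besides the tiled ops,
// the implementation reports every extract_slice it created, including those
// nested inside regions (e.g. both branches of an scf.if).
struct MockTilingResult {
  std::vector<std::string> tiledOps;
  std::vector<std::string> extractSliceOps;
};

// Models the fusion worklist after this PR. `fuseProducerOf` maps a candidate
// slice to the result of tiling + fusing its producer; slices without an
// entry have no fusable producer. Returns the order in which slices are
// processed.
std::vector<std::string>
fuseWithWorklist(const MockTilingResult &consumer,
                 const std::map<std::string, MockTilingResult> &fuseProducerOf) {
  std::deque<std::string> candidates(consumer.extractSliceOps.begin(),
                                     consumer.extractSliceOps.end());
  std::vector<std::string> processed;
  while (!candidates.empty()) {
    // Traverse the slices in BFS order.
    std::string slice = candidates.front();
    candidates.pop_front();
    processed.push_back(slice);
    auto it = fuseProducerOf.find(slice);
    if (it == fuseProducerOf.end())
      continue;
    // Enqueue the slices reported by the fused producer instead of
    // re-deriving them from the operands of its last tiled op.
    for (const std::string &newSlice : it->second.extractSliceOps)
      candidates.push_back(newSlice);
  }
  return processed;
}
```

For the concat example, seeding with `{"slice_arg1", "slice_arg2"}` and letting the producer behind `slice_arg1` report one further slice `slice_x` yields the processing order `slice_arg1`, `slice_arg2`, `slice_x`, so both concat inputs are now visited. The sketch ignores everything the real loop does per slice beyond enqueueing new candidates.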
---
Full diff: https://github.com/llvm/llvm-project/pull/95422.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+2)
- (modified) mlir/include/mlir/Interfaces/TilingInterface.h (+4)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+7-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+9-2)
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+19-18)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..fecd33193eb0d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,7 @@ struct SCFTilingResult {
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
+ SmallVector<Operation *> extractSliceOps;
};
/// Method to tile an op that implements the `TilingInterface` using
@@ -135,6 +136,7 @@ struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
+ SmallVector<Operation *> extractSliceOps;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index ca570490ccf5b..e5ed016d53fc1 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -28,9 +28,13 @@ namespace mlir {
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
+/// - `extractSliceOps` contains all the `tensor.extract_slice` ops used in
+/// generating the `tiledOps`. Usually these are operands to the `tiledOps`
+/// but they can be embedded in regions owned by `tiledOps`.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
+ SmallVector<Operation *> extractSliceOps;
};
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..5198e0bceaa6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2501,7 +2501,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ SmallVector<Operation *> sliceOps;
+ for (Value operand : tiledOperands)
+ if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+ sliceOps.push_back(sliceOp);
+
+ return TilingResult{
+ {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
}
LogicalResult SoftmaxOp::getResultTilePosition(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..f25ccc38ba0a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -129,7 +129,13 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ SmallVector<Operation *> sliceOps;
+ for (Value operand : tiledOperands)
+ if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+ sliceOps.push_back(sliceOp);
+
+ return TilingResult{
+ {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
}
/// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -247,7 +253,8 @@ struct LinalgOpTilingInterface
return TilingResult{
tilingResult->tiledOps,
- SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+ SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+ tilingResult->extractSliceOps};
}
/// Method to generate the tiled implementation of an operation from the tile
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..fb3ec2a5fa0a8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -619,7 +619,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
if (llvm::all_of(tileSizes, isZeroIndex)) {
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
tilingResult =
- TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+ TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+ /*extractSliceOps=*/{}};
return success();
}
@@ -675,12 +676,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// op.
if (loops.empty()) {
return scf::SCFTilingResult{tilingResult->tiledOps, loops,
- tilingResult->tiledValues};
+ tilingResult->tiledValues,
+ tilingResult->extractSliceOps};
}
SmallVector<Value> replacements = llvm::map_to_vector(
loops.front()->getResults(), [](OpResult r) -> Value { return r; });
- return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+ return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+ tilingResult->extractSliceOps};
}
FailureOr<scf::SCFReductionTilingResult>
@@ -931,9 +934,9 @@ mlir::scf::tileAndFuseProducerOfSlice(
->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
.set(origDestinationTensors[resultNumber]);
}
- return scf::SCFFuseProducerOfSliceResult{fusableProducer,
- tileAndFuseResult->tiledValues[0],
- tileAndFuseResult->tiledOps};
+ return scf::SCFFuseProducerOfSliceResult{
+ fusableProducer, tileAndFuseResult->tiledValues[0],
+ tileAndFuseResult->tiledOps, tileAndFuseResult->extractSliceOps};
}
/// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -962,13 +965,12 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
.getDefiningOp<DestinationStyleOpInterface>()) {
rewriter.setInsertionPoint(tiledDestStyleOp);
Value newRegionArg = newRegionIterArgs.back();
- auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
- sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
unsigned resultNumber = fusableProducer.getResultNumber();
- rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
- tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
- });
+ auto origSlice = tiledDestStyleOp.getDpsInits()[resultNumber]
+ .getDefiningOp<tensor::ExtractSliceOp>();
+ if (origSlice) {
+ origSlice.getSourceMutable().set(newRegionArg);
+ }
}
Block *block = rewriter.getInsertionPoint()->getBlock();
rewriter.setInsertionPoint(block->getTerminator());
@@ -1036,15 +1038,14 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// operations. If the producers of the source of the `tensor.extract_slice`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
- auto addCandidateSlices = [](Operation *fusedOp,
+ auto addCandidateSlices = [](const SmallVector<Operation *> &newSliceOps,
std::deque<tensor::ExtractSliceOp> &candidates) {
- for (Value operand : fusedOp->getOperands())
- if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
- candidates.push_back(sliceOp);
+ for (auto *op : newSliceOps)
+ candidates.push_back(llvm::cast<tensor::ExtractSliceOp>(op));
};
std::deque<tensor::ExtractSliceOp> candidates;
- addCandidateSlices(tiledAndFusedOps.back(), candidates);
+ addCandidateSlices(tilingResult->extractSliceOps, candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
@@ -1086,7 +1087,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
tiledAndFusedOps.insert(tiledAndFusedOp);
- addCandidateSlices(tiledAndFusedOp, candidates);
+ addCandidateSlices(fusedResult->extractSliceOps, candidates);
}
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/95422