[Mlir-commits] [mlir] [mlir][TilingInterface] Handle multi operand consumer fusion. (PR #145193)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jun 21 16:14:14 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: None (MaheshRavishankar)
<details>
<summary>Changes</summary>
For consumer fusion cases of this form
```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {
tensor.parallel_insert_slice ... into %arg0
tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```
the current consumer fusion, which handles one slice at a time, cannot fuse the consumer into the loop, since fusing along one slice will create an SSA violation on the other use from the `scf.forall`. The solution is to allow consumer fusion to consider multiple slices at once. This PR changes the `TilingInterface` methods related to consumer fusion, i.e.
- `getTiledImplementationFromOperandTile`
- `getIterationDomainTileFromOperandTile`
to allow fusion while considering multiple operands. It is up to the `TilingInterface` implementation to return an error if a list of tiles of the operands cannot result in a consistent implementation of the tiled operation.
The Linalg operation implementation of `TilingInterface` has been modified to account for these changes and to handle the cases where the operand tiles result in a consistent tiling implementation.
Additional change: add `LLVM_DUMP_METHOD` to `OpFoldResult::dump` to preserve the symbol in Debug builds.
---
Patch is 60.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145193.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+14-9)
- (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+9-5)
- (modified) mlir/include/mlir/IR/OpDefinition.h (+1-1)
- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+27-26)
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+118-64)
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+142-67)
- (modified) mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp (+37-10)
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+289)
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+27-20)
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+4-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..7b6e3cba5723d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -319,19 +319,24 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place. Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer of the results of all the elements of `candidateSlices`
+/// by computing the required slice of the consumer in-place. All the entries of
+/// `candidateSlices` are expected to map to the same consumer. The method
+/// returns an error if the consumer cannot be tiled in a manner that is
+/// consistent for all the passed slices. Note that the method replaces the uses
+/// of `candidateSlices` with the tiled and fused consumer value but does not
+/// delete the slice operations.
struct SCFFuseConsumerOfSliceResult {
- OpOperand *origConsumerOperand; // Original untiled consumer's operand.
- OpOperand
- *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+ // Original untiled consumer's operands.
+ SmallVector<OpOperand *> origConsumerOperands;
+ // Tiled and fused consumer's operands.
+ SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
- MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+ ArrayRef<Operation *> candidateSlices,
+ MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..8f6eb1bd47782 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap `tensor.insert_slice`s with their consumer when the
+/// consumer implements the `TilingInterface`. The size of `sliceOps` and
+/// `consumerOperands` is expected to be the same. Every entry in
+/// `consumerOperands` represents the use of the result of the corresponding
+/// entry in `sliceOps`. All entries of `consumerOperands` are expected to be
+/// uses in the same consumer.
FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
- OffsetSizeAndStrideOpInterface sliceOp,
- OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+ ArrayRef<tensor::InsertSliceOp> sliceOps,
+ ArrayRef<OpOperand *> consumerOperands);
//===----------------------------------------------------------------------===//
// Populate functions.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 31f54413a5ff0..663c256c848df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
using PointerUnion<Attribute, Value>::PointerUnion;
public:
- void dump() const { llvm::errs() << *this << "\n"; }
+ LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
MLIRContext *getContext() const {
PointerUnion pu = *this;
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..7ebdd8907e964 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation that uses
- exactly a tile of the given operand.
+ exactly tiles of the given operands.
This method is required to allow operations to be "tiled and fused"
- with an (already tiled) producer. Given a tile of the producer, this
- method generates the tile of the consumer that uses exactly this
- produced tile. In some sense it is the "reverse" of
+ with an (already tiled) producer. Given tiles of the producer, this
+ method generates the tile of the consumer that uses exactly these
+ produced tiles. In some sense it is the "reverse" of
`generateResultTileValue`.
- - `operandNumber` is the result of the producer used by the consumer.
- - `offsets` is the offset of the slice of the producer result used by
- the tiled implementation of the consumer.
- - `sizes` is the size of the slice of the producer result used by the
+ - `operandNumbers` is the list of operands whose tiles are "produced".
+ - `allOffsets` is the offset of the slice of the producer results used
+ by the tiled implementation of the consumer.
+ - `allSizes` is the size of the slice of the producer results used by the
consumer.
- If it is illegal to fuse with a producer along the given operand for
+ If it is illegal to fuse with a producer along the given operand tiles for
an operation, the implementation should return a failure.
}],
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
- /*methodName=*/"getTiledImplementationFromOperandTile",
+ /*methodName=*/"getTiledImplementationFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
- "unsigned":$operandNumber,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -235,13 +235,14 @@ def TilingInterface : OpInterface<"TilingInterface"> {
tile of the operand.
This method is required to allow operations to be "tiled and fused"
- with an (already tiled) producer. Given a tile of an operand,
- returns the tile of the iteration space that uses this tile.
- - `operandNumber` is the result of the producer used by the consumer.
- - `offsets` is the offset of the slice of the producer result used by
- the tiled implementation of the consumer.
- - `sizes` is the size of the slice of the producer result used by the
- consumer.
+ with an (already tiled) producer. Given tiles of the operands,
+ returns the tile of the iteration space that uses these tiles.
+ - `operandNumbers` is the list of operands whose tiles are "produced"
+ by the producer(s).
+ - `allOffsets` is the offset of the slice of the producer results
+ used by the tiled implementation of the consumer.
+ - `allSizes` is the size of the slice of the producer results used by
+ the consumer.
If it is illegal to fuse with a producer along the given operand for
an operation, or if this mapping cannot be computed, the
implementation should return a failure.
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformation. It does not provide guarantees on whether such a
transformation is profitable.
- For most cases `getTiledImplementationFromOperandTile` could be a
- implemented using `getIterationDomainTileFromOperandTile` +
+ For most cases `getTiledImplementationFromOperandTiles` could be
+ implemented using `getIterationDomainTileFromOperandTiles` +
`getTiledImplementation` methods.
}],
/*retType=*/"::llvm::LogicalResult",
- /*methodName=*/"getIterationDomainTileFromOperandTile",
+ /*methodName=*/"getIterationDomainTileFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
- "unsigned":$operandNumber,
- "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
- "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..86045b54075bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -147,55 +147,80 @@ struct LinalgOpTilingInterface
/// Utility to fetch the offsets and sizes when applied as per the indexing
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
/// a given slice op.
- void
- getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<OpFoldResult> &mappedOffsets,
- SmallVectorImpl<OpFoldResult> &mappedSizes) const {
- unsigned numLoops = linalgOp.getNumLoops();
- auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
- mappedOffsets.resize(numLoops);
- mappedSizes.resize(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
- mappedOffsets[index] = value.offset;
- mappedSizes[index] = value.size;
+ static LogicalResult
+ getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+ ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
+ SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
+ SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
+ DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+
+ for (auto [indexingMap, offsets, sizes] :
+ llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+ for (auto [resultExpr, offset, size] :
+ llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+ if (!dimExpr)
+ continue;
+ unsigned position = dimExpr.getPosition();
+ auto it = mappedOffsets.find(position);
+ if (it != mappedOffsets.end()) {
+ OpFoldResult seenOffset = it->second;
+ OpFoldResult seenSize = mappedSizes.lookup(position);
+ if (seenOffset != offset || seenSize != size) {
+ return linalgOp->emitOpError(
+ "inconsistent iteration space mapping from offsets/sizes of "
+ "operands/results");
+ }
+ } else {
+ mappedOffsets[position] = offset;
+ mappedSizes[position] = size;
+ }
}
}
- for (const auto &&[index, value] :
- llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
- mappedOffsets[dimPosition] = offsets[index];
- mappedSizes[dimPosition] = sizes[index];
+
+ // Use the offset/size implied by an operand tile for each loop when one
+ // exists; otherwise default to the full iteration-domain offset/size.
+ SmallVector<Range> iterationDomain =
+ cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
+ mappedOffsetsVec.resize(iterationDomain.size());
+ mappedSizesVec.resize(iterationDomain.size());
+ for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
+ auto it = mappedOffsets.find(index);
+ if (it != mappedOffsets.end()) {
+ mappedOffsetsVec[index] = it->second;
+ mappedSizesVec[index] = mappedSizes.lookup(index);
+ continue;
+ }
+ mappedOffsetsVec[index] = domain.offset;
+ mappedSizesVec[index] = domain.size;
}
+ return success();
}
/// Method to return the position of the result tile computed by the tiled
/// operation.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
- // Check that the indexing map used for the operand is a projected
- // permutation. This could be relaxed with a more general approach that can
- // map the offsets and sizes from the operand to iteration space tiles
- // (filling in full extent for dimensions not used to access the result).
- AffineMap indexingMap =
- linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
- if (!indexingMap.isProjectedPermutation()) {
- return op->emitError()
- << "unhandled get iter domain position when operand is not "
- "accessed using a permuted projection";
+ std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
+ iterationSpaceSizes;
+ SmallVector<AffineMap> indexingMaps =
+ llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+ OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+ return linalgOp.getMatchingIndexingMap(&opOperand);
+ });
+ if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
+ allSizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return failure();
}
-
- getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- iterDomainOffsets, iterDomainSizes);
return success();
}
@@ -246,8 +271,13 @@ struct LinalgOpTilingInterface
"accessed using a permuted projection");
}
- getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- iterDomainOffsets, iterDomainSizes);
+ SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
+ SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
+ auto status =
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
+ {allSizes}, iterDomainOffsets, iterDomainSizes);
+ (void)status;
+ assert(succeeded(status) && "unexpected error in offset calculation");
return success();
}
@@ -278,12 +308,13 @@ struct LinalgOpTilingInterface
/// Method to generate the tiled implementation of an operation from the tile
/// of the operand.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, operandNumber, offsets, sizes, mappedOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
mappedSizes))) {
return failure();
}
@@ -750,13 +781,17 @@ struct PackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
/// `resultSizes` only cover outer dimensions.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
- if (operandNumber != 0)
- return failure();
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+ return op->emitOpError("unsupporeted operands for consumer fusion");
+
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
// It is not trivial to infer dest tile from source tile if `packOp` has
@@ -817,11 +852,15 @@ struct PackOpTiling
}
/// Method to return the tiled implementation of tensor.pack as a consumer.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
- if (operandNumber != 0)
- return failure();
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+ return op->emitOpError("unhandled operands for consumer fusion");
+
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
Location loc = packOp.getLoc();
@@ -836,8 +875,8 @@ struct PackOpTiling
tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
outerDimSizes)))
return failure();
@@ -1095,12 +1134,20 @@ struct UnPackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpB...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/145193
More information about the Mlir-commits
mailing list