[Mlir-commits] [mlir] [mlir][TilingInterface] Handle multi operand consumer fusion. (PR #145193)
llvmlistbot at llvm.org
Sun Jun 22 21:23:04 PDT 2025
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/145193
>From c86e95b1cc1f076a69cc86850174106c3f8c664a Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 18 Jun 2025 17:05:43 -0700
Subject: [PATCH] [mlir][TilingInterface] Handle multi operand consumer fusion.
For consumer fusion cases of this form
```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {
tensor.parallel_insert_slice ... into %arg0
tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```
the current consumer fusion, which handles one slice at a time, cannot
fuse the consumer into the loop, since fusing along one slice would
create an SSA violation on the other use of the `scf.forall` results.
The solution is to allow consumer fusion to consider multiple slices at
once. This PR changes the `TilingInterface` methods related to consumer
fusion, i.e.
- `getTiledImplementationFromOperandTile`
- `getIterationDomainTileFromOperandTile`
to allow fusion while considering multiple operands. It is up to the
`TilingInterface` implementation to return an error if the given list of
operand tiles cannot result in a consistent implementation of the tiled
operation.
The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes and to handle the cases where the
operand tiles result in a consistent tiling of the operation.
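With this change, the fused IR for the example above looks roughly like the
sketch below (illustrative only; slice operands, types, and exact names are
elided; see the new `multi_slice_fusion` tests in
`tile-and-fuse-consumer.mlir` for the precise expected IR). The consumer is
tiled inside the loop using both produced tiles, and its result is yielded
through an additional `shared_outs` operand:
```
%0:3 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ..., %arg2 = ...) {
  %tiled:2 = <tiled producer>
  // Tiled consumer fused in-place, consuming both in-loop tiles.
  %fused = linalg.generic ... ins(%tiled#0, %tiled#1) outs(%arg2_slice)
  tensor.parallel_insert_slice %tiled#0 into %arg0
  tensor.parallel_insert_slice %tiled#1 into %arg1
  tensor.parallel_insert_slice %fused into %arg2
}
// Uses of the original consumer result are replaced with %0#2.
```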
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../SCF/Transforms/TileUsingInterface.h | 22 +-
.../Dialect/Tensor/Transforms/Transforms.h | 14 +-
.../mlir/Interfaces/TilingInterface.td | 55 ++--
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 182 +++++++----
.../SCF/Transforms/TileUsingInterface.cpp | 209 +++++++++----
.../SwapExtractSliceWithProducerPatterns.cpp | 47 ++-
.../tile-and-fuse-consumer.mlir | 289 ++++++++++++++++++
.../TestTilingInterfaceTransformOps.cpp | 47 +--
.../TestTilingInterfaceTransformOps.td | 7 +-
9 files changed, 667 insertions(+), 205 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..c3b79fc860208 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -319,19 +319,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place. Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer of `candidateSlices` by computing the required slice of the
+/// consumer in-place. All the entries of `candidateSlices` are expected to map
+/// to the same consumer. The method returns an error if the consumer cannot be
+/// tiled in a manner that is consistent for all the passed slices. Note that
+/// the method replaces the uses of `candidateSlices` with the tiled and fused
+/// consumer value but does not delete the slice operations.
struct SCFFuseConsumerOfSliceResult {
- OpOperand *origConsumerOperand; // Original untiled consumer's operand.
- OpOperand
- *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+ // Original untiled consumer operands.
+ SmallVector<OpOperand *> origConsumerOperands;
+ // Tiled and fused consumer operands.
+ SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
- MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+ ArrayRef<Operation *> candidateSlices,
+ MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..9860b06348407 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap `tensor.insert_slice`s with their consumer when the
+/// consumer implements the `TilingInterface`. `sliceOps` and
+/// `consumerOperands` are expected to have the same size. Every entry in
+/// `consumerOperands` represents a use of the corresponding entry in
+/// `sliceOps` in the consumer. All entries of `consumerOperands` are
+/// expected to be uses in the same consumer.
FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
- OffsetSizeAndStrideOpInterface sliceOp,
- OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+ ArrayRef<tensor::InsertSliceOp> sliceOps,
+ ArrayRef<OpOperand *> consumerOperands);
//===----------------------------------------------------------------------===//
// Populate functions.
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..30d42993f23ec 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation that uses
- exactly a tile of the given operand.
+ exactly tiles of the given operands.
This method is required to allow operations to be "tiled and fused"
- with an (already tiled) producer. Given a tile of the producer, this
- method generates the tile of the consumer that uses exactly this
- produced tile. In some sense it is the "reverse" of
+ with an (already tiled) producer. Given tiles of the producer, this
+ method generates the tile of the consumer that uses exactly these
+ produced tiles. In some sense it is the "reverse" of
`generateResultTileValue`.
- - `operandNumber` is the result of the producer used by the consumer.
- - `offsets` is the offset of the slice of the producer result used by
- the tiled implementation of the consumer.
- - `sizes` is the size of the slice of the producer result used by the
+      - `operandNumbers` is the list of operands whose tiles are "produced" by the producer(s).
+      - `allOffsets` are the offsets of the slices of the producers used by the
+        tiled implementation of the consumer.
+      - `allSizes` are the sizes of the slices of the producers used by the
consumer.
- If it is illegal to fuse with a producer along the given operand for
+ If it is illegal to fuse with a producer along the given operand tiles for
an operation, the implementation should return a failure.
}],
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
- /*methodName=*/"getTiledImplementationFromOperandTile",
+ /*methodName=*/"getTiledImplementationFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
- "unsigned":$operandNumber,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
tile of the operand.
This method is required to allow operations to be "tiled and fused"
- with an (already tiled) producer. Given a tile of an operand,
- returns the tile of the iteration space that uses this tile.
- - `operandNumber` is the result of the producer used by the consumer.
- - `offsets` is the offset of the slice of the producer result used by
+ with an (already tiled) producer. Given tiles of operands,
+ returns the tile of the iteration space that uses these tiles.
+ - `operandNumbers` is the list of operands whose tiles are "produced"
+ by the producer(s).
+      - `allOffsets` are the offsets of the slices of the producers used by
the tiled implementation of the consumer.
- - `sizes` is the size of the slice of the producer result used by the
+      - `allSizes` are the sizes of the slices of the producers used by the
consumer.
- If it is illegal to fuse with a producer along the given operand for
- an operation, or if this mapping cannot be computed, the
- implementation should return a failure.
+ If it is illegal to fuse with the producer slices for an operation,
+ or if this mapping cannot be computed, the implementation should
+ return a failure.
Note that unlike the "tile consumer and fuse producer" case, the
"tile producer and fuse consumer" requires an additional method to get
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformation. It does not provide guarantees on whether such a
transformation is profitable.
- For most cases `getTiledImplementationFromOperandTile` could be a
- implemented using `getIterationDomainTileFromOperandTile` +
+    For most cases `getTiledImplementationFromOperandTiles` could be
+ implemented using `getIterationDomainTileFromOperandTiles` +
`getTiledImplementation` methods.
}],
/*retType=*/"::llvm::LogicalResult",
- /*methodName=*/"getIterationDomainTileFromOperandTile",
+ /*methodName=*/"getIterationDomainTileFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
- "unsigned":$operandNumber,
- "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
- "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..86045b54075bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -147,55 +147,80 @@ struct LinalgOpTilingInterface
/// Utility to fetch the offsets and sizes when applied as per the indexing
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
/// a given slice op.
- void
- getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<OpFoldResult> &mappedOffsets,
- SmallVectorImpl<OpFoldResult> &mappedSizes) const {
- unsigned numLoops = linalgOp.getNumLoops();
- auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
- mappedOffsets.resize(numLoops);
- mappedSizes.resize(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
- mappedOffsets[index] = value.offset;
- mappedSizes[index] = value.size;
+ static LogicalResult
+ getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+ ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
+ SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
+ SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
+ DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+
+ for (auto [indexingMap, offsets, sizes] :
+ llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+ for (auto [resultExpr, offset, size] :
+ llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+ if (!dimExpr)
+ continue;
+ unsigned position = dimExpr.getPosition();
+ auto it = mappedOffsets.find(position);
+ if (it != mappedOffsets.end()) {
+ OpFoldResult seenOffset = it->second;
+ OpFoldResult seenSize = mappedSizes.lookup(position);
+ if (seenOffset != offset || seenSize != size) {
+ return linalgOp->emitOpError(
+ "inconsistent iteration space mapping from offsets/sizes of "
+ "operands/results");
+ }
+ } else {
+ mappedOffsets[position] = offset;
+ mappedSizes[position] = size;
+ }
}
}
- for (const auto &&[index, value] :
- llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
- mappedOffsets[dimPosition] = offsets[index];
- mappedSizes[dimPosition] = sizes[index];
+
+ // Aggregate from the given operand offsets and sizes, or default to
+ // iteration space values.
+ SmallVector<Range> iterationDomain =
+ cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
+ mappedOffsetsVec.resize(iterationDomain.size());
+ mappedSizesVec.resize(iterationDomain.size());
+ for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
+ auto it = mappedOffsets.find(index);
+ if (it != mappedOffsets.end()) {
+ mappedOffsetsVec[index] = it->second;
+ mappedSizesVec[index] = mappedSizes.lookup(index);
+ continue;
+ }
+ mappedOffsetsVec[index] = domain.offset;
+      mappedSizesVec[index] = domain.size;
}
+ return success();
}
/// Method to return the position of the result tile computed by the tiled
/// operation.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
- // Check that the indexing map used for the operand is a projected
- // permutation. This could be relaxed with a more general approach that can
- // map the offsets and sizes from the operand to iteration space tiles
- // (filling in full extent for dimensions not used to access the result).
- AffineMap indexingMap =
- linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
- if (!indexingMap.isProjectedPermutation()) {
- return op->emitError()
- << "unhandled get iter domain position when operand is not "
- "accessed using a permuted projection";
+ SmallVector<AffineMap> indexingMaps =
+ llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+ OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+ return linalgOp.getMatchingIndexingMap(&opOperand);
+ });
+ if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
+ allSizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return failure();
}
-
- getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- iterDomainOffsets, iterDomainSizes);
return success();
}
@@ -246,8 +271,13 @@ struct LinalgOpTilingInterface
"accessed using a permuted projection");
}
- getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- iterDomainOffsets, iterDomainSizes);
+ SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
+ SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
+ auto status =
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
+ {allSizes}, iterDomainOffsets, iterDomainSizes);
+ (void)status;
+ assert(succeeded(status) && "unexpected error in offset calculation");
return success();
}
@@ -278,12 +308,13 @@ struct LinalgOpTilingInterface
/// Method to generate the tiled implementation of an operation from the tile
/// of the operand.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, operandNumber, offsets, sizes, mappedOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
mappedSizes))) {
return failure();
}
@@ -750,13 +781,17 @@ struct PackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
/// `resultSizes` only cover outer dimensions.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
- if (operandNumber != 0)
- return failure();
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+ return op->emitOpError("unsupporeted operands for consumer fusion");
+
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
// It is not trivial to infer dest tile from source tile if `packOp` has
@@ -817,11 +852,15 @@ struct PackOpTiling
}
/// Method to return the tiled implementation of tensor.pack as a consumer.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
- if (operandNumber != 0)
- return failure();
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+ return op->emitOpError("unhandled operands for consumer fusion");
+
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
Location loc = packOp.getLoc();
@@ -836,8 +875,8 @@ struct PackOpTiling
tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
outerDimSizes)))
return failure();
@@ -1095,12 +1134,20 @@ struct UnPackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ if (operandNumbers.size() != 1) {
+ return op->emitOpError("unable to handle multiple operands");
+ }
auto unPackOp = cast<UnPackOp>(op);
+ unsigned operandNumber = operandNumbers[0];
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
+
// If the operand tile is the dest, then no adjustment is needed.
if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
resultOffsets = llvm::to_vector(offsets);
@@ -1154,10 +1201,17 @@ struct UnPackOpTiling
}
/// Method to return the tiled implementation of tensor.unpack as a consumer.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
+ return op->emitOpError("unhandled operands for consumer fusion");
+ }
auto unPackOp = cast<UnPackOp>(op);
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
+
// tensor.unpack op is fusible (as a consumer) only if inner dims are not
// tiled.
int64_t numTiles = unPackOp.getInnerDimsPos().size();
@@ -1172,8 +1226,8 @@ struct UnPackOpTiling
// Fetch offset/size for creating the slice of the dest operand of
// unpack op.
SmallVector<OpFoldResult> outputOffsets, outputSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
outputSizes)))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 3f29dd3ac5e48..4e19c2770f722 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1977,53 +1977,118 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
- MutableArrayRef<LoopLikeOpInterface> loops) {
+static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
+ RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "unexpected empty loops");
- if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
- } else if (auto parallelInsertSlice =
- dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
- } else {
- return failure();
+ assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
+ SmallVector<OpOperand *> fusedOperands;
+ for (auto sliceOp : sliceOps) {
+ FailureOr<OpOperand *> fusedOperand =
+ TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
+ .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+ [&](auto op) {
+ return getUntiledConsumerFromSlice(rewriter, op, loops);
+ })
+ .Default([](Operation *op) {
+ return op->emitOpError("unhandled slice type");
+ });
+ if (failed(fusedOperand)) {
+ return failure();
+ }
+ if (!fusedOperands.empty() &&
+ fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
+ return fusedOperands.front()->getOwner()->emitOpError(
+ "all candidate slices must be to the same consumer");
+ }
+ fusedOperands.push_back(fusedOperand.value());
+ }
+ return fusedOperands;
+}
+
+template <typename InsertSliceOpTy>
+static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
+ InsertSliceOpTy sliceOp);
+
+template <>
+tensor::InsertSliceOp
+cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
+ tensor::InsertSliceOp insertSliceOp) {
+ return cast<tensor::InsertSliceOp>(
+ rewriter.clone(*insertSliceOp.getOperation()));
+}
+
+template <>
+tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
+ RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
+ return rewriter.create<tensor::InsertSliceOp>(
+ insertSliceOp->getLoc(), insertSliceOp.getSource(),
+ insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
+ insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+}
+
+static SmallVector<tensor::InsertSliceOp>
+cloneAsInsertSlices(RewriterBase &rewriter,
+ ArrayRef<Operation *> candidateSlices) {
+ assert(!candidateSlices.empty() &&
+ "unexpected empty list of slices to clone");
+ SmallVector<tensor::InsertSliceOp> clonedSlices;
+ for (auto sliceOp : candidateSlices) {
+ TypeSwitch<Operation *>(sliceOp)
+ .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+ [&](auto op) {
+ auto clonedOp = cloneAsInsertSlice(rewriter, op);
+ clonedSlices.push_back(clonedOp);
+ })
+ .Default([&](Operation *op) {
+ // Assert here assuming this has already been checked.
+ assert(0 && "unexpected slice type while cloning as insert slice");
+ });
}
+ return clonedSlices;
}
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(
- RewriterBase &rewriter, Operation *candidateSliceOp,
+mlir::scf::tileAndFuseConsumerOfSlices(
+ RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops) {
+ if (candidateSlices.empty()) {
+ return emitError(rewriter.getUnknownLoc(),
+ "no candidate slices provided for consumer fusion");
+ }
// Return if `loops` is empty, return an error for now. Caller is expected
// to handle this case.
if (loops.empty()) {
- return candidateSliceOp->emitOpError(
+ return candidateSlices.front()->emitOpError(
"cannot call tile and fuse consumer with an empty loop nest");
}
- if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
- candidateSliceOp))
- return failure();
+
+ if (!(llvm::all_of(
+ candidateSlices,
+ [](Operation *op) { return isa<tensor::InsertSliceOp>(op); }) ||
+ llvm::all_of(candidateSlices, [](Operation *op) {
+ return isa<tensor::ParallelInsertSliceOp>(op);
+ }))) {
+ return candidateSlices.front()->emitOpError(
+ "candidates slices need to be all `tensor.extract_slice`s or "
+ "`tensor.parallel_insert_slice`s");
+ }
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
- FailureOr<OpOperand *> maybeConsumerOpOperand =
- getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
- if (failed(maybeConsumerOpOperand)) {
- return rewriter.notifyMatchFailure(candidateSliceOp,
- "could not fetch consumer to fuse");
- }
- OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
- Operation *consumerOp = consumerOpOperand->getOwner();
- unsigned operandNumber = consumerOpOperand->getOperandNumber();
- unsigned resultNumber = 0;
- if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
- resultNumber = producerResult.getResultNumber();
- } else {
- return rewriter.notifyMatchFailure(
- consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+ SmallVector<OpOperand *> consumerOpOperands;
+ Operation *consumerOp;
+ {
+ FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+ getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+ if (failed(maybeConsumerOpOperand)) {
+ return rewriter.notifyMatchFailure(candidateSlices.front(),
+ "could not fetch consumer to fuse");
+ }
+ std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+ consumerOp = consumerOpOperands.front()->getOwner();
}
LoopLikeOpInterface outerMostLoop = loops.front();
@@ -2043,16 +2108,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
if (!dstOp)
return rewriter.notifyMatchFailure(consumerOp,
"consumer op is not DPS operation");
- SmallVector<Value> dpsInits =
- llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
+ if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
+ return dstOp.isDpsInit(opOperand);
+ })) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- SmallVector<Value> newInits = dpsInits;
-
- Location loc = outerMostLoop->getLoc();
+ SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
// 3. Move the whole loop structure right before firstUserOfLoop, the
// dominance should be already ensured by `checkAssumptionForLoop`.
@@ -2067,43 +2130,52 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
- tensor::InsertSliceOp clonedInsertSliceOp;
if (auto sliceOp =
- dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
rewriter.setInsertionPoint(newForallOp.getTerminator());
- clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else {
- rewriter.setInsertionPoint(candidateSliceOp);
- clonedInsertSliceOp =
- cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+ rewriter.setInsertionPoint(candidateSlices.front());
}
+ // 5.a. Clone all the candidate slices as equivalent insert slice ops.
+ SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
+ cloneAsInsertSlices(rewriter, candidateSlices);
- // 5.a. Clone consumer op.
+ // 5.b. Clone consumer op.
auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
+ SmallVector<unsigned> operandNumbers =
+ llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
+ return opOperand->getOperandNumber();
+ });
+ SmallVector<OpOperand *> clonedOpFusedOperandsList =
+ llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
+ return &clonedConsumerOp->getOpOperand(operandNum);
+ });
- // 5.b. Replace all uses of the loop result with the result of the cloned
+ // 5.c. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
- OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
- operandToReplace.set(clonedInsertSliceOp.getResult());
+ for (auto [operandToReplace, clonedSliceOp] :
+ llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
+ operandToReplace->set(clonedSliceOp.getResult());
+ }
});
// 6. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
- auto ossSliceOp =
- cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
FailureOr<TilingResult> tileAndFuseResult =
- tensor::replaceInsertSliceWithTiledConsumer(
- rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+ tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
+ clonedOpFusedOperandsList);
if (failed(tileAndFuseResult)) {
return failure();
}
+
auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
- rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
- clonedInsertSliceOp.getSource());
+ for (auto [operandNum, clonedSliceOp] :
+ llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
+ rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
+ clonedSliceOp.getSource());
+ }
// 7. Reconstruct [nested] loop with new inits.
YieldTiledValuesFn newYieldValuesFn =
@@ -2115,14 +2187,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// 8. Set inner insertPoint right before tiled consumer op.
innerRewriter.setInsertionPoint(tiledConsumerOp);
- SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+ SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
+ for (auto candidateSliceOp : clonedInsertSlices) {
+ SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
- // 9. Check all insert stride is 1.
- if (!llvm::all_of(strides, isOneInteger)) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
+    // 9. Collect the offsets and sizes from all the candidate slices.
+ allOffsets.emplace_back(std::move(offsets));
+ allSizes.emplace_back(std::move(sizes));
}
// 10. Try to get iter domain position from input position. Use
@@ -2132,8 +2204,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// tiledConsumerOp could lead to some chained unnecessary extra index
// computation.
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
+ rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
iterDomainSizes))) {
return rewriter.notifyMatchFailure(
clonedConsumerOp,
@@ -2209,10 +2281,13 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// 16. Need to erase the old scf loop and the cloned consumer op.
rewriter.eraseOp(clonedConsumerOp);
+ SmallVector<OpOperand *> tiledAndFusedOpOperands =
+ llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
+ return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
+ });
return scf::SCFFuseConsumerOfSliceResult{
- consumerOpOperand,
- &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
- tileAndFuseResult->tiledOps};
+ std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+ std::move(tileAndFuseResult->tiledOps)};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 6f33f9b55ceb6..10831e790d739 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -39,21 +39,48 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return *tiledResult;
}
-FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
- OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
- OpOperand &consumer) {
- auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
+FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
+ OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
+ ArrayRef<OpOperand *> consumerOperands) {
+ if (sliceOps.empty()) {
+ return emitError(builder.getUnknownLoc(),
+ "expected candidate slices list to be non-empty");
+ }
+ if (sliceOps.size() != consumerOperands.size()) {
+ return sliceOps.front()->emitOpError(
+ "expected as many operands as the number of slices passed");
+ }
+ auto consumerOp =
+ dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
if (!consumerOp)
return failure();
+ for (auto opOperand : consumerOperands.drop_front()) {
+ if (opOperand->getOwner() != consumerOp) {
+ return consumerOp->emitOpError(
+ "expected all consumer operands to be from the same operation");
+ }
+ }
- // `TilingInterface` currently only supports strides being 1.
- if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
- return failure();
+ auto consumerOperandNums = llvm::map_to_vector(
+ consumerOperands, [](OpOperand *opOperand) -> unsigned {
+ return opOperand->getOperandNumber();
+ });
+ SmallVector<SmallVector<OpFoldResult>> allOffsets;
+ SmallVector<SmallVector<OpFoldResult>> allSizes;
+ for (auto sliceOp : sliceOps) {
+
+ // `TilingInterface` currently only supports strides being 1.
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
+ return failure();
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+ allOffsets.emplace_back(std::move(offsets));
+ allSizes.emplace_back(std::move(sizes));
+ }
FailureOr<TilingResult> tiledResult =
- consumerOp.getTiledImplementationFromOperandTile(
- builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes());
+ consumerOp.getTiledImplementationFromOperandTiles(
+ builder, consumerOperandNums, allOffsets, allSizes);
if (failed(tiledResult))
return failure();
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 77e52946b830f..569f12d5b615e 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -620,3 +620,292 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+ %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+ %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+ %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+ %generic:2 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.mulf %b0, %b1 : f32
+ %1 = arith.addf %b0, %b2 : f32
+ linalg.yield %0, %1 : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ %empty = tensor.empty(%dim0) : tensor<?xf32>
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion1(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+// CHECK: %[[TILESIZE:.+]] = affine.min
+// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: %[[FUSED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
+// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// Check that the consumer can be fused when its operand tiles are produced by two different operations within the loop.
+
+func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+ %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+ %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+ %generic0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.mulf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+ %generic1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0: f32
+ } -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ %empty = tensor.empty(%dim0) : tensor<?xf32>
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion2(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+// CHECK: %[[TILESIZE:.+]] = affine.min
+// CHECK: %[[GENERIC0:.+]] = linalg.generic
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: %[[FUSED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
+// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
+ %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+ %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+ shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
+ %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+ %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+ %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+ : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+ %generic0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.mulf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
+ %generic1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0: f32
+ } -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ // expected-error @below {{unable to apply consumer fusion along these operands since the iteration spaces are inconsistent}}
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+ %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+ %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+ shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+ %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+ %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+ : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+ %generic0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.mulf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+ %generic1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0: f32
+ } -> tensor<?x?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ }
+ }
+ %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ // expected-error @below {{unable to apply consumer fusion along these operands since the iteration spaces are inconsistent}}
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 9971f0cde4ed2..3c810f6f4accf 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -168,29 +168,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
/// Apply fusing of consumer transformation to all payload ops and store both
/// the original consumer operation as well as the fused consumer operation.
-template <typename Range>
static LogicalResult applyFuseConsumer(
- RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
- MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
- TransformResults &transformResults) {
+ RewriterBase &rewriter, Operation *transformOp,
+ ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
+ uint32_t numConsumerToFuse, TransformResults &transformResults) {
SmallVector<Operation *> originalConsumerOps;
SmallVector<Operation *> fusedConsumerOps;
- for (Operation *target : payloadOps) {
- rewriter.setInsertionPoint(target);
+ rewriter.setInsertionPoint(slices.front());
- while (numConsumerToFuse--) {
- FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
- scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
+ while (numConsumerToFuse--) {
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+ scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
- if (failed(fuseConsumerResults))
- return failure();
+ if (failed(fuseConsumerResults))
+ return failure();
- // Report back the relevant handles to the transform op.
- originalConsumerOps.push_back(
- fuseConsumerResults->origConsumerOperand->getOwner());
- fusedConsumerOps.push_back(
- fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+ // Report back the relevant handles to the transform op.
+ for (OpOperand *origConsumerOperand :
+ fuseConsumerResults->origConsumerOperands) {
+ originalConsumerOps.push_back(origConsumerOperand->getOwner());
+ }
+ for (OpOperand *tiledAndFusedConsumerOperand :
+ fuseConsumerResults->tiledAndFusedConsumerOperands) {
+ fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
}
}
@@ -203,6 +204,12 @@ DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) {
+ SmallVector<Operation *> slices;
+ for (auto op : getTargets()) {
+ auto sliceOp = *state.getPayloadOps(op).begin();
+ slices.push_back(sliceOp);
+ }
+
SmallVector<LoopLikeOpInterface> loops;
for (auto op : llvm::reverse(getLoops())) {
auto loopLikeOp =
@@ -212,16 +219,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
}
loops.push_back(loopLikeOp);
}
- LogicalResult result = applyFuseConsumer(
- rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
- getNumConsumerToFuse(), transformResults);
+ LogicalResult result =
+ applyFuseConsumer(rewriter, getOperation(), slices, loops,
+ getNumConsumerToFuse(), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
void transform::TestFuseConsumerOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getTargetMutable(), effects);
+ consumesHandle(getTargetsMutable(), effects);
consumesHandle(getLoopsMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 98f7145c99cb1..3c09082e192ea 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
@@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
}];
let arguments = (ins
- TransformHandleTypeInterface:$target,
+ Variadic<TransformHandleTypeInterface>:$targets,
Variadic<TransformHandleTypeInterface>:$loops,
DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
let results = (outs TransformHandleTypeInterface:$consumer,
TransformHandleTypeInterface:$fused_consumer);
let assemblyFormat = [{
- $target `in` `(` $loops `)`
+ $targets `in` `(` $loops `)`
(`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
attr-dict `:` functional-type(operands, results)
}];