[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Abhishek Varma
llvmlistbot at llvm.org
Wed May 22 01:15:57 PDT 2024
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/88712
>From 3b493574f946bdd2a792d8e36753a1566752d93d Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Sat, 16 Mar 2024 09:12:12 +0800
Subject: [PATCH 01/10] [mlir][linalg] Enable fuse consumer
This patch adds support for consumer fusion to the tiling interface.
- Add interface method 'getIterationDomainTileFromOperandTile' to tiling
interface which gets the iteration domain position from the operand position.
- Add interface method 'getTiledImplementationFromOperandTile' to tiling
interface which generates the tiled implementation according to the operand position.
---
.../mlir/Interfaces/TilingInterface.td | 67 ++++++++++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 106 +++++++++++++-----
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 12 +-
4 files changed, 149 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c2424..84f7dec2f4003 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
The method returns the operation that is the tiled
implementation.
}],
- /*retType=*/"FailureOr<TilingResult>",
+ /*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementation",
/*args=*/(ins
"OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
by the tiled implementation. Expects the same `offsets` and `sizes` as
used to obtain the tiled implementation of the operation.
}],
- /*retType=*/"LogicalResult",
+ /*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getResultTilePosition",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$resultNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
- "SmallVector<OpFoldResult> &":$resultOffsets,
- "SmallVector<OpFoldResult> &":$resultSizes),
+ "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
+ "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the position of the iteration domain tile required to
+ compute the given tile of an operand.
+ }],
+ /*retType=*/"::mlir::LogicalResult",
+ /*methodName=*/"getIterationDomainTileFromOperandTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
iteration space).
- `sizes` provides the size of the tile.
}],
- /*retType=*/"FailureOr<TilingResult>",
+ /*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to generate the tiled implementation of an operation from
+ operand tile position.
+
+ Generates the IR that computes the tiled implementation of an
+ operation from operand tile. The `offsets` and `sizes`
+ describe the tile of the operand required. This is different from
+ `getTiledImplementation` which generates the tiled
+ implementation of the operation given a tile of the
+ iteration space. This method generates a tiled
+ implementation of the operation based on the tile of the
+ operand required. This method enables consumer fusion by using
+ tile and fuse. The method returns failure if the operation
+ can't be tiled to generate the operand tile. In practical terms
+ this implies it cannot be tiled and fused with its producers.
+
+ - `offsets` provides the offset of the tile on the operand in the
+ coordinate system of the original operand, i.e., if a dimension of
+ the operand had a non-zero offset, it must be included in the
+ offset provided here (as opposed to a zero-based offset "relative"
+ to the operand).
+ - `sizes` provides the size of the tile on the operand.
+ }],
+ /*retType=*/"FailureOr<::mlir::TilingResult>",
+ /*methodName=*/"getTiledImplementationFromOperandTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult>":$offsets,
+ "ArrayRef<OpFoldResult>":$sizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformations are done, this method can be used to lower to scalar
code that can then be lowered to LLVM or SPIR-V dialects.
}],
- /*retType=*/"LogicalResult",
+ /*retType=*/"::mlir::LogicalResult",
/*methodName=*/"generateScalarImplementation",
/*args=*/(ins
"OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..03716eaaa6358 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2454,8 +2454,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
LogicalResult SoftmaxOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) {
+ ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) {
if (resultNumber == 0) {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5..71e9c3771dcde 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
}));
}
- // Instantiate the tiled implementation of the operation.
+ /// Instantiate the tiled implementation of the operation.
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
- // Return the details of the output tile generated by the tiled
- // implementation.
+ void
+ getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &mappedOffsets,
+ SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+ unsigned numLoops = linalgOp.getNumLoops();
+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+ mappedOffsets.resize(numLoops);
+ mappedSizes.resize(numLoops);
+ if (!indexingMap.isPermutation()) {
+ SmallVector<Range> iterationDomain =
+ tilingInterfaceOp.getIterationDomain(b);
+ for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+ mappedOffsets[index] = value.offset;
+ mappedSizes[index] = value.size;
+ }
+ }
+ for (const auto &&[index, value] :
+ llvm::enumerate(indexingMap.getResults())) {
+ unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
+ mappedOffsets[dimPosition] = offsets[index];
+ mappedSizes[dimPosition] = sizes[index];
+ }
+ }
+
+ /// Return the iteration domain tile that is required to compute the given
+ /// tile of the operand.
+ LogicalResult getIterationDomainTileFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+ SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+
+ // Check that the indexing map used for the operand is a projected
+ // permutation. This could be relaxed with a more general approach that can
+ // map the offsets and sizes from the operand to iteration space tiles
+ // (filling in full extent for dimensions not used to access the result).
+ AffineMap indexingMap =
+ linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+ if (!indexingMap.isProjectedPermutation()) {
+ return emitError(op->getLoc(),
+ "unhandled get iter domain position when operand is not "
+ "accessed using a permuted projection");
+ }
+
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+ iterDomainOffsets, iterDomainSizes);
+ return success();
+ }
+
+ /// Return the details of the output tile generated by the tiled
+ /// implementation.
LogicalResult
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
return success();
}
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+ return emitError(
+ op->getLoc(),
+ "unable to obtain the iter domain position of the operation.");
+ }
+ return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+ mappedSizes);
+ }
+
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
-
- auto numLoops = linalgOp.getNumLoops();
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+ mappedOffsets, mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
FailureOr<TilingResult> tilingResult =
- tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
- iterationTileSizes);
+ tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+ if (failed(tilingResult))
+ return failure();
+
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d25efcf50ec56..296c5fc7a5c2b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
// The iteration domain is over outer dimensions of packed layout. In this
// context, the outer dimensions of `resultOffsets` are `offsets`. The
// inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
resultOffsets = llvm::to_vector(offsets);
resultSizes = llvm::to_vector(sizes);
return success();
>From ec05f4cb6a58d93fdf26c5fa5b208978c83fc1bf Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 10 Apr 2024 10:41:46 +0000
Subject: [PATCH 02/10] [MLIR][SCF] Add an API to fuse consumer to a producer
within scf loop
-- This commit adds an API to fuse a consumer to a producer within
an scf.for/scf.forall loop.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../SCF/Transforms/TileUsingInterface.h | 13 +
.../Dialect/Tensor/Transforms/Transforms.h | 13 +-
.../SCF/Transforms/TileUsingInterface.cpp | 429 ++++++++++++++++++
.../SwapExtractSliceWithProducerPatterns.cpp | 23 +
.../tile-and-fuse-consumer.mlir | 258 +++++++++++
.../TestTilingInterfaceTransformOps.cpp | 51 +++
.../TestTilingInterfaceTransformOps.td | 19 +
7 files changed, 805 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..c744a17ae9191 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -14,6 +14,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include <deque>
@@ -239,6 +240,18 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);
+/// Fuse the consumer of the source of `candidateSliceOp` by computing the
+/// required slice of the consumer in-place. Note that the method
+/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
+/// value but does not delete the slice operation.
+struct SCFFuseConsumerOfSliceResult {
+ Operation *origConsumer; // Original untiled consumer.
+ Operation *tiledAndFusedConsumer; // Tiled and fused consumer op.
+ SmallVector<Operation *> tiledOps;
+};
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
FailureOr<SmallVector<scf::ForOp>>
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e8a09c4741043..98447cf62900d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
@@ -22,7 +23,7 @@ namespace tensor {
// Patterns
//===----------------------------------------------------------------------===//
-/// Pattern to swap an `tensor.extract_slice` with its producer when the
+/// Method to swap a `tensor.extract_slice` with its producer when the
/// producer implements the `TilingInterface`. The pattern itself does not
/// provide a mechanism to control where the application happens. With use of
/// transform dialect that control is done within the transform dialect. Other
@@ -30,6 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
+/// Method to swap a `tensor.insert_slice` with its consumer when the
+/// consumer implements the `TilingInterface`. The pattern itself does not
+/// provide a mechanism to control where the application happens. With use of
+/// transform dialect that control is done within the transform dialect. Other
+/// use cases can inherit from this pattern and add necessary controls.
+FailureOr<TilingResult>
+replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
+ OffsetSizeAndStrideOpInterface sliceOp,
+ OpOperand &consumerOp);
+
//===----------------------------------------------------------------------===//
// Populate functions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69d..d4147e2b0602f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -16,9 +16,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -1100,6 +1102,433 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// A utility function that checks whether the passed value has exactly one
+/// use, that this use is an scf.yield, and that the yield is in the same
+/// block as the value's defining op.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+ Value::use_range uses = result.getUses();
+ if (!llvm::hasSingleElement(uses)) {
+ LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+ return failure();
+ }
+ OpOperand &operandUse = (*uses.begin());
+ Operation *userOp = operandUse.getOwner();
+ if (!isa<scf::YieldOp>(userOp)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Expected scf.yield to be the only user, but got -> "
+ << (*userOp));
+ return failure();
+ }
+ if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
+ LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+ "be in the same block\n");
+ return failure();
+ }
+ return success();
+}
+
+/// Fetch the first untiled consumer of a scf.for's result which is yielded by
+/// a tensor.insert_slice. This function makes the following assumptions :-
+/// 1. tensor.insert_slice has scf.yield as its only user.
+/// 2. scf.for's corresponding result has only one use.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+ Value sliceResult = candidateSliceOp.getResult();
+ if (failed(checkAssumptionForFusingConsumer(candidateSliceOp.getResult()))) {
+ return nullptr;
+ }
+ // Step 1. Fetch the corresponding output.
+ OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+ unsigned resultNumber = yieldOpOperand.getOperandNumber();
+ // Step 2. Check containing op is scf.for.
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ auto forOp = dyn_cast<scf::ForOp>(containingOp);
+ if (!forOp) {
+ return nullptr;
+ }
+ Value resultingValue = forOp->getResult(resultNumber);
+
+ // Step 3. Check resulting value of scf.for has exactly one use.
+ if (!llvm::hasSingleElement(resultingValue.getUses())) {
+ return nullptr;
+ }
+
+ // Step 4. Get uses.
+ OpOperand &operand = (*resultingValue.getUses().begin());
+ Operation *consumerOp = operand.getOwner();
+ // TODO: We have to init result of consumer before scf.for, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp)) {
+ return nullptr;
+ }
+ if (containingOp->getBlock() != consumerOp->getBlock()) {
+ return nullptr;
+ }
+ return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+ tensor::InsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of scf.for for the result yielded by
+ // tensor.insert_slice.
+ OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
+ if (!consumerOpOperand) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = consumerOpOperand->getOwner();
+ unsigned operandNumber = consumerOpOperand->getOperandNumber();
+ unsigned resultNumber =
+ cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+
+ auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(candidateSliceOp);
+
+ // 2. Check consumer is not using scf.for's output as init.
+ auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.for as init is not supported");
+ }
+
+ Location loc = forOp.getLoc();
+ SmallVector<Value> newOuts(forOp.getInits());
+ newOuts.append(dpsInits);
+
+ // 3. Create new scf.for op.
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+ forOp.getUpperBound(),
+ forOp.getStep(), newOuts);
+ // 4. Move the loop body to the new op.
+ Block *loopBody = forOp.getBody();
+ Block *newLoopBody = newforOp.getBody();
+ rewriter.mergeBlocks(
+ loopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+ // 5.a. Clone consumer after the cloned tensor.insert_slice op.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ auto newForOpBlockArgsForConsumerDest =
+ newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
+ auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // 5.b. Replace all uses of the loop result with the result of the cloned
+ // tensor.insert_slice.
+ OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ operandToReplace.set(candidateSliceOp.getResult());
+ });
+
+ // 6 - Perform tiling of the cloned consumer.
+ rewriter.setInsertionPointAfter(clonedConsumerOp);
+ FailureOr<TilingResult> tileAndFuseResult =
+ tensor::replaceInsertSliceWithTiledConsumer(
+ rewriter,
+ cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
+ clonedConsumerOp->getOpOperand(operandNumber));
+ if (failed(tileAndFuseResult)) {
+ return rewriter.notifyMatchFailure(clonedConsumerOp,
+ "failed to tile consumer op: ");
+ }
+
+ // 7 - Extract offset/sizes/strides required to create the tensor.insert_slice
+ // for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+ // 8. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
+ // 9. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ rewriter.setInsertionPointAfter(clonedConsumerOp);
+ if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp, "can't get iter domain position from input position");
+ }
+
+ // 10. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.insert_slice later.
+ unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(clonedConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+
+ // 11. Fix terminator.
+ scf::YieldOp oldTerminatorOp =
+ cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+ SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+ rewriter.setInsertionPointAfter(oldTerminatorOp);
+ MutableArrayRef<BlockArgument> bbArgs = newforOp.getBody()->getArguments();
+ for (auto [idx, v] :
+ llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+ rewriter.getIndexAttr(1));
+ Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ candidateSliceOp->getLoc(), v,
+ bbArgs[1 + forOp.getInits().size() + idx], resultOffsets[idx],
+ resultSizes[idx], strides);
+ newYieldOperands.push_back(newInsertSliceOp);
+ }
+ rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+ rewriter.eraseOp(oldTerminatorOp);
+
+ // 12. Replace the result of scf.for and consumer op.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip_first(forOp.getResults(), newforOp.getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
+
+ for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(
+ oldValue, newforOp->getResult(forOp.getInits().size() + index));
+ }
+
+ // 13. Need to erase the old scf.for and the cloned consumer op.
+ rewriter.eraseOp(forOp);
+ rewriter.eraseOp(clonedConsumerOp);
+
+ return scf::SCFFuseConsumerOfSliceResult{
+ consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // Step 1. Fetch the corresponding output
+ Value sliceDest = candidateSliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(sliceDest);
+ Operation *containingOp = iterArg.getOwner()->getParentOp();
+ // Step 2. Check that the containing op is scf.forall.
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ return nullptr;
+ }
+ Value resultingValue =
+ forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ // Step 3. Check resulting value of scf.forall has exactly one use.
+ Value::use_range uses = resultingValue.getUses();
+ if (!llvm::hasSingleElement(uses)) {
+ return nullptr;
+ }
+
+ // Step 4. Get uses.
+ OpOperand &operand = (*resultingValue.getUses().begin());
+ Operation *consumerOp = operand.getOwner();
+ // TODO: We have to init result of consumer before scf.forall, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp)) {
+ return nullptr;
+ }
+ if (containingOp->getBlock() != consumerOp->getBlock()) {
+ return nullptr;
+ }
+ return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+ RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the dest.
+ OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
+ if (!consumerOpOperand) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = consumerOpOperand->getOwner();
+ unsigned operandNumber = consumerOpOperand->getOperandNumber();
+ unsigned resultNumber =
+ cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ // Using candidateSliceOp->getParentOp() because we have the following case :-
+ // scf.forall.in_parallel {
+ // tensor.parallel_insert_slice ...
+ // }
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+ Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+ auto forallOp = cast<scf::ForallOp>(containingOp);
+
+ // 2. Check consumer is not using scf.forall's output as init.
+ auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.forall as init is not supported");
+ }
+
+ Location loc = forallOp.getLoc();
+ // 3. Create new scf.forall op.
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.append(dpsInits);
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+ // 4. Move the loop body to the new op.
+ rewriter.eraseOp(newforallOp.getTerminator());
+ Block *loopBody = forallOp.getBody();
+ Block *newLoopBody = newforallOp.getBody();
+ rewriter.mergeBlocks(
+ loopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+ // 5.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ auto newForOpBlockArgsForConsumerDest =
+ newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
+ auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // 5.b. Replace all uses of the scf.forall's result use in the consumer with
+ // the source of the cloned tensor.parallel_insert_slice.
+ OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ operandToReplace.set(candidateSliceOp.getSource());
+ });
+
+ // 6. Perform tiling of the cloned consumer.
+ rewriter.setInsertionPoint(newforallOp.getTerminator());
+ FailureOr<TilingResult> tileAndFuseResult =
+ tensor::replaceInsertSliceWithTiledConsumer(
+ rewriter,
+ cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
+ clonedConsumerOp->getOpOperand(operandNumber));
+ if (failed(tileAndFuseResult)) {
+ return rewriter.notifyMatchFailure(clonedConsumerOp,
+ "failed to tile consumer op: ");
+ }
+
+ // 7. Extract offset/sizes/strides required to create the
+ // tensor.parallel_insert_slice for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+ // 8. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
+ // 9. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
+ if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp, "can't get iter domain position from input position");
+ }
+
+ // 10. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.parallel_insert_slice later.
+ unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(clonedConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+
+ // 11. Fix terminator.
+ scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+ rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+ Location firstYieldOpLoc =
+ (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+ MutableArrayRef<BlockArgument> bbArgs = newforallOp.getBody()->getArguments();
+ for (auto [idx, v] :
+ llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+ rewriter.getIndexAttr(1));
+ rewriter.create<tensor::ParallelInsertSliceOp>(
+ firstYieldOpLoc, v,
+ bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+ resultOffsets[idx], resultSizes[idx], strides);
+ }
+
+ // 12. Replace the result of scf.forall and consumer op.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip_first(forallOp.getResults(), newforallOp.getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
+
+ for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(
+ oldValue, newforallOp->getResult(forallOp.getOutputs().size() + index));
+ }
+
+ // 13. Need to erase the old scf.forall and cloned consumer.
+ rewriter.eraseOp(forallOp);
+ rewriter.eraseOp(clonedConsumerOp);
+
+ return scf::SCFFuseConsumerOfSliceResult{
+ consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place. Dispatches on the kind of candidate
+/// slice op: tensor.insert_slice selects the scf.for path, while
+/// tensor.parallel_insert_slice selects the scf.forall path.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp))
+    return tileAndFuseConsumerOfSliceSCFFor(rewriter, insertSlice);
+  if (auto parallelInsertSlice =
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp))
+    return tileAndFuseConsumerOfSliceSCFForall(rewriter, parallelInsertSlice);
+  // Any other op kind is not a supported fusion anchor.
+  return failure();
+}
+
//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 40d79c2053817..858adfc436164 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -40,3 +40,26 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return *tiledResult;
}
+
+/// Replace the given insert-slice-like op by a tiled implementation of the
+/// consumer that uses it, materialized for the tile described by the slice's
+/// offsets and sizes. Fails if the consumer does not implement
+/// `TilingInterface` or if any stride differs from 1.
+FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
+    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
+    OpOperand &consumer) {
+  auto tilingInterfaceOp = dyn_cast<TilingInterface>(consumer.getOwner());
+  if (!tilingInterfaceOp)
+    return failure();
+
+  // `TilingInterface` currently only supports strides being 1.
+  if (!llvm::all_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+        return isConstantIntValue(ofr, 1);
+      }))
+    return failure();
+
+  // Delegate tile materialization to the interface; failure propagates
+  // directly through the returned FailureOr.
+  return tilingInterfaceOp.getTiledImplementationFromOperandTile(
+      builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
+      sliceOp.getMixedSizes());
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
new file mode 100644
index 0000000000000..3d60e32bfa0cc
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -0,0 +1,258 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+// Test case: fuse a single-result elementwise consumer into the scf.for
+// tiling loop of its producer, anchored at the tensor.insert_slice that
+// yields the producer's tile.
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+// Transform script: locate the tensor.insert_slice and fuse its consumer.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %yield
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// The fused loop must carry one extra iter_arg (the consumer's init) and
+// yield the tiled consumer result via an insert_slice into it.
+// CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %0 = tensor.empty() : tensor<64xf32>
+// CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:   iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0)
+// CHECK-SAME:   {
+// CHECK:   %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK:   %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:   outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+// CHECK:   %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK:   %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
+// CHECK:   %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+// CHECK:   %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK:   %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:   ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:   outs(%[[SLICE_OUT]] :
+// CHECK:   %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK:   scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+// CHECK:   }
+// CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+// Test case: fuse an elementwise consumer into an scf.forall loop, anchored
+// at a tensor.parallel_insert_slice inside scf.forall.in_parallel.
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+        tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}
+
+// Transform script: the forall yields via two parallel_insert_slice ops;
+// split the handle and fuse the consumer at the first one.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+      : (!transform.any_op)
+      -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %first_slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// The new forall must grow one shared_out for the consumer init and write
+// the tiled consumer result in the in_parallel terminator.
+// CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+// CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+// CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:   shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+// CHECK:   %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:   outs(%[[MAT_OUT_SLICE]] :
+// CHECK:   %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:   ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:   outs(%[[SLICE_OUT]] :
+// CHECK:   scf.forall.in_parallel {
+// CHECK:   tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   }
+// CHECK:   }
+// CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+// Test case: same as the scf.for test above, but the fused consumer yields
+// two results, so the new loop must carry two extra iter_args and two
+// tensor.insert_slice ops for the consumer's results.
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %out_operand_4 = tensor.empty() : tensor<64xf32>
+    %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) {
+    ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+      %13 = arith.mulf %in, %in_16 : f32
+      %14 = arith.subf %out_0, %13 : f32
+      %15 = arith.addf %out_1, %in : f32
+      linalg.yield %14, %15 : f32, f32
+    } -> (tensor<64xf32>, tensor<64xf32>)
+    return %2#1 : tensor<64xf32>
+  }
+}
+
+// Transform script: anchor at the tensor.insert_slice as before.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %yield
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %0 = tensor.empty() : tensor<64xf32>
+// CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:   iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %0, %[[ELEM_OUT_ARG_1:.*]] = %0)
+// CHECK-SAME:   {
+// CHECK:   %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK:   %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:   outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+// CHECK:   %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK:   %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
+// CHECK:   %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+// CHECK:   %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+// CHECK:   %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+// CHECK:   %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:   ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:   outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+// CHECK:   %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+// CHECK:   %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+// CHECK:   scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
+// CHECK:   }
+// CHECK:   return %[[FINAL_RESULT]]#3 :
+
+// -----
+
+// Test case: scf.forall variant with a two-result consumer; the new forall
+// must grow two shared_outs and emit two parallel_insert_slice ops for the
+// tiled consumer's results.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+        tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_4 = tensor.empty() : tensor<64x64xf32>
+    %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64x64xf32>, tensor<64x64xf32>) {
+    ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+      %13 = arith.mulf %in, %in_16 : f32
+      %14 = arith.subf %out_0, %13 : f32
+      %15 = arith.addf %out_1, %in : f32
+      linalg.yield %14, %15 : f32, f32
+    } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+    return %2#1 : tensor<64x64xf32>
+  }
+}
+
+// Transform script: split the two parallel_insert_slice handles and fuse at
+// the first one.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+      : (!transform.any_op)
+      -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %first_slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+// CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+// CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:   shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+// CHECK:   %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:   outs(%[[MAT_OUT_SLICE]] :
+// CHECK:   %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:   ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:   outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+// CHECK:   scf.forall.in_parallel {
+// CHECK:   tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK:   }
+// CHECK:   }
+// CHECK:   return %[[FINAL_RESULT]]#3 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 335db1a61f476..181d8ebc68f9e 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -160,6 +160,57 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
: DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerOp
+//===----------------------------------------------------------------------===//
+
+/// Apply the fuse-consumer transformation to every payload op and record
+/// both the original consumer operation and the fused (tiled) consumer
+/// operation as results of the transform op.
+template <typename Range>
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+                  Range &&payloadOps, TransformResults &transformResults) {
+  SmallVector<Operation *> origConsumers;
+  SmallVector<Operation *> newConsumers;
+
+  for (Operation *slice : payloadOps) {
+    rewriter.setInsertionPoint(slice);
+
+    // Fuse the consumer of this candidate slice into its containing loop.
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> result =
+        scf::tileAndFuseConsumerOfSlice(rewriter, slice);
+    if (failed(result))
+      return failure();
+
+    // Collect the handles to report back to the transform op.
+    origConsumers.push_back(result->origConsumer);
+    newConsumers.push_back(result->tiledAndFusedConsumer);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), origConsumers);
+  transformResults.set(transformOp->getOpResult(1), newConsumers);
+  return success();
+}
+
+/// Transform-op entry point: run the shared helper over all payload ops
+/// mapped to the target handle and translate its LogicalResult into the
+/// transform dialect's diagnostic result type.
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  if (failed(applyFuseConsumer(rewriter, getOperation(),
+                               state.getPayloadOps(getTarget()),
+                               transformResults)))
+    return DiagnosedSilenceableFailure::definiteFailure();
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declare the transform's effects: the target handle is consumed, the two
+/// result handles are produced, and the payload IR is modified in place.
+void transform::TestFuseConsumerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getConsumer(), effects);
+  producesHandle(getFusedConsumer(), effects);
+  modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// TestTileUsingForallOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index ef42375e5286d..d55d746bd6aa9 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}];
}
+def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  // NOTE: the previous description mentioned "options provided as
+  // attributes", but this op declares no attributes; the text below matches
+  // the actual interface.
+  let description = [{
+    Fuses the consumer of the slice operation pointed to by the target
+    handle into the loop containing that slice, returning handles to both
+    the original consumer and the fused (tiled) consumer.
+  }];
+
+  let arguments =
+      (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$consumer,
+                      TransformHandleTypeInterface:$fused_consumer);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+}
+
def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
>From 4af443d4fc9ab995eb17ff8ed1ec2dd5f6742ae7 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 2 May 2024 05:29:36 +0000
Subject: [PATCH 03/10] WIP unify and move common code out
---
.../SCF/Transforms/TileUsingInterface.cpp | 761 ++++++++++--------
1 file changed, 414 insertions(+), 347 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index d4147e2b0602f..eac23bfec4fa4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1173,11 +1173,12 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
return &operand;
}
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf.for.
-static FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
- tensor::InsertSliceOp candidateSliceOp) {
+/// Helper aggregate holding the operand of the untiled consumer that uses
+/// the value yielded through the candidate slice op.
+// Fix: `typedef struct helper1struct { ... };` declared a struct tag but no
+// typedef-name, triggering a "typedef requires a name" warning; in C++ a
+// plain struct definition is the correct form.
+struct helper1struct {
+  OpOperand *consumerOpOperand;
+};
+
+template <typename T>
+static helper1struct helper1(T candidateSliceOp) {
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice.
OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
@@ -1190,377 +1191,443 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
unsigned resultNumber =
cast<OpResult>(consumerOpOperand->get()).getResultNumber();
- auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+ if (std::is_same<T, tensor::InsertSliceOp>::value) {
+ auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(candidateSliceOp);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(candidateSliceOp);
- // 2. Check consumer is not using scf.for's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
- SmallVector<Value> dpsInits =
- llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
- return rewriter.notifyMatchFailure(
- consumerOp,
- "consumer op taking the result of scf.for as init is not supported");
- }
+ // 2. Check consumer is not using scf.for's output as init.
+ auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.for as init is not supported");
+ }
- Location loc = forOp.getLoc();
- SmallVector<Value> newOuts(forOp.getInits());
- newOuts.append(dpsInits);
-
- // 3. Create new scf.for op.
- rewriter.setInsertionPoint(consumerOp);
- auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
- forOp.getUpperBound(),
- forOp.getStep(), newOuts);
- // 4. Move the loop body to the new op.
- Block *loopBody = forOp.getBody();
- Block *newLoopBody = newforOp.getBody();
- rewriter.mergeBlocks(
- loopBody, newLoopBody,
- newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+ Location loc = forOp.getLoc();
+ SmallVector<Value> newOuts(forOp.getInits());
+ newOuts.append(dpsInits);
+
+ // 3. Create new scf.for op.
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+ forOp.getUpperBound(),
+ forOp.getStep(), newOuts);
+ // 4. Move the loop body to the new op.
+ Block *loopBody = forOp.getBody();
+ Block *newLoopBody = newforOp.getBody();
+ rewriter.mergeBlocks(
+ loopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+ helper1struct ret;
+ ret.consumerOpOperand = consumerOpOperand;
+ ret.newLoop = newforOp;
+ return ret;
+ }
+
+ /// Implementation of fusing consumer of a single slice by computing the
+ /// slice of the consumer in-place for scf.for.
+ static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+ tileAndFuseConsumerOfSliceSCFFor(RewriterBase & rewriter,
+ tensor::InsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of scf.for for the result yielded by
+ // tensor.insert_slice.
+ OpOperand *consumerOpOperand =
+ getUntiledConsumerFromSlice(candidateSliceOp);
+ if (!consumerOpOperand) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = consumerOpOperand->getOwner();
+ unsigned operandNumber = consumerOpOperand->getOperandNumber();
+ unsigned resultNumber =
+ cast<OpResult>(consumerOpOperand->get()).getResultNumber();
- // 5.a. Clone consumer after the cloned tensor.insert_slice op.
- rewriter.setInsertionPointAfter(candidateSliceOp);
- auto newForOpBlockArgsForConsumerDest =
- newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
- auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
- rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+ auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
- // 5.b. Replace all uses of the loop result with the result of the cloned
- // tensor.insert_slice.
- OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
- rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
- operandToReplace.set(candidateSliceOp.getResult());
- });
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(candidateSliceOp);
- // 6 - Perform tiling of the cloned consumer.
- rewriter.setInsertionPointAfter(clonedConsumerOp);
- FailureOr<TilingResult> tileAndFuseResult =
- tensor::replaceInsertSliceWithTiledConsumer(
- rewriter,
- cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
- clonedConsumerOp->getOpOperand(operandNumber));
- if (failed(tileAndFuseResult)) {
- return rewriter.notifyMatchFailure(clonedConsumerOp,
- "failed to tile consumer op: ");
- }
+ // 2. Check consumer is not using scf.for's output as init.
+ auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.for as init is not supported");
+ }
- // 7 - Extract offset/sizes/strides required to create the tensor.insert_slice
- // for each result of the consumer.
- SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
- // 8. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
- // 9. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- rewriter.setInsertionPointAfter(clonedConsumerOp);
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp, "can't get iter domain position from input position");
- }
+ Location loc = forOp.getLoc();
+ SmallVector<Value> newOuts(forOp.getInits());
+ newOuts.append(dpsInits);
+
+ // 3. Create new scf.for op.
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+ forOp.getUpperBound(),
+ forOp.getStep(), newOuts);
+ // 4. Move the loop body to the new op.
+ Block *loopBody = forOp.getBody();
+ Block *newLoopBody = newforOp.getBody();
+ rewriter.mergeBlocks(
+ loopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+ // 5.a. Clone consumer after the cloned tensor.insert_slice op.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ auto newForOpBlockArgsForConsumerDest =
+ newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
+ auto clonedConsumerOp =
+ cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // 5.b. Replace all uses of the loop result with the result of the cloned
+ // tensor.insert_slice.
+ OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ operandToReplace.set(candidateSliceOp.getResult());
+ });
+
+ // 6 - Perform tiling of the cloned consumer.
+ rewriter.setInsertionPointAfter(clonedConsumerOp);
+ FailureOr<TilingResult> tileAndFuseResult =
+ tensor::replaceInsertSliceWithTiledConsumer(
+ rewriter,
+ cast<OffsetSizeAndStrideOpInterface>(
+ candidateSliceOp.getOperation()),
+ clonedConsumerOp->getOpOperand(operandNumber));
+ if (failed(tileAndFuseResult)) {
+ return rewriter.notifyMatchFailure(clonedConsumerOp,
+ "failed to tile consumer op: ");
+ }
- // 10. Try to fetch the offset and size for all results of the cloned
- // consumer. This would then be used to form the corresponding
- // tensor.insert_slice later.
- unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
- for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
- if (failed(clonedConsumerOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
+ // 7 - Extract offset/sizes/strides required to create the
+ // tensor.insert_slice for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+ // 8. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
+ // 9. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ rewriter.setInsertionPointAfter(clonedConsumerOp);
+ if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
return rewriter.notifyMatchFailure(
clonedConsumerOp,
- "can't get result domain position from iter domain position");
+ "can't get iter domain position from input position");
}
- }
-
- // 11. Fix terminator.
- scf::YieldOp oldTerminatorOp =
- cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
- SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
- rewriter.setInsertionPointAfter(oldTerminatorOp);
- MutableArrayRef<BlockArgument> bbArgs = newforOp.getBody()->getArguments();
- for (auto [idx, v] :
- llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
- SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
- rewriter.getIndexAttr(1));
- Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- candidateSliceOp->getLoc(), v,
- bbArgs[1 + forOp.getInits().size() + idx], resultOffsets[idx],
- resultSizes[idx], strides);
- newYieldOperands.push_back(newInsertSliceOp);
- }
- rewriter.create<scf::YieldOp>(loc, newYieldOperands);
- rewriter.eraseOp(oldTerminatorOp);
-
- // 12. Replace the result of scf.for and consumer op.
- for (auto &&[oldResult, newResult] :
- llvm::zip_first(forOp.getResults(), newforOp.getResults())) {
- rewriter.replaceAllUsesWith(oldResult, newResult);
- }
-
- for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
- rewriter.replaceAllUsesWith(
- oldValue, newforOp->getResult(forOp.getInits().size() + index));
- }
-
- // 13. Need to erase the old scf.for and the cloned consumer op.
- rewriter.eraseOp(forOp);
- rewriter.eraseOp(clonedConsumerOp);
- return scf::SCFFuseConsumerOfSliceResult{
- consumerOp, tileAndFuseResult->tiledOps[0], {}};
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static OpOperand *
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
- // Step 1. Fetch the corresponding output
- Value sliceDest = candidateSliceOp.getDest();
- auto iterArg = cast<BlockArgument>(sliceDest);
- Operation *containingOp = iterArg.getOwner()->getParentOp();
- // Step 2. Check that the containing op is scf.forall.
- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp) {
- return nullptr;
- }
- Value resultingValue =
- forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
- // Step 3. Check resulting value of scf.forall has exactly one use.
- Value::use_range uses = resultingValue.getUses();
- if (!llvm::hasSingleElement(uses)) {
- return nullptr;
- }
-
- // Step 4. Get uses.
- OpOperand &operand = (*resultingValue.getUses().begin());
- Operation *consumerOp = operand.getOwner();
- // TODO: We have to init result of consumer before scf.forall, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp)) {
- return nullptr;
- }
- if (containingOp->getBlock() != consumerOp->getBlock()) {
- return nullptr;
- }
- return &operand;
-}
-
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf.forall.
-static FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSliceSCFForall(
- RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
- // 1. Get the consumer of the dest.
- OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
- if (!consumerOpOperand) {
- return rewriter.notifyMatchFailure(candidateSliceOp,
- "could not fetch consumer to fuse");
- }
- Operation *consumerOp = consumerOpOperand->getOwner();
- unsigned operandNumber = consumerOpOperand->getOperandNumber();
- unsigned resultNumber =
- cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+ // 10. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.insert_slice later.
+ unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(
+ totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(clonedConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
- OpBuilder::InsertionGuard g(rewriter);
- // Using candidateSliceOp->getParentOp() because we have the following case :-
- // scf.forall.in_parallel {
- // tensor.parallel_insert_slice ...
- // }
- rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+ // 11. Fix terminator.
+ scf::YieldOp oldTerminatorOp =
+ cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+ SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+ rewriter.setInsertionPointAfter(oldTerminatorOp);
+ MutableArrayRef<BlockArgument> bbArgs = newforOp.getBody()->getArguments();
+ for (auto [idx, v] :
+ llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+ rewriter.getIndexAttr(1));
+ Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ candidateSliceOp->getLoc(), v,
+ bbArgs[1 + forOp.getInits().size() + idx], resultOffsets[idx],
+ resultSizes[idx], strides);
+ newYieldOperands.push_back(newInsertSliceOp);
+ }
+ rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+ rewriter.eraseOp(oldTerminatorOp);
- Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
- auto forallOp = cast<scf::ForallOp>(containingOp);
+ // 12. Replace the result of scf.for and consumer op.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip_first(forOp.getResults(), newforOp.getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
- // 2. Check consumer is not using scf.forall's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
- SmallVector<Value> dpsInits =
- llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
- return rewriter.notifyMatchFailure(
- consumerOp,
- "consumer op taking the result of scf.forall as init is not supported");
- }
+ for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(
+ oldValue, newforOp->getResult(forOp.getInits().size() + index));
+ }
- Location loc = forallOp.getLoc();
- // 3. Create new scf.forall op.
- SmallVector<Value> newOuts(forallOp.getOutputs());
- newOuts.append(dpsInits);
- rewriter.setInsertionPoint(consumerOp);
- auto newforallOp = rewriter.create<scf::ForallOp>(
- loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-
- // 4. Move the loop body to the new op.
- rewriter.eraseOp(newforallOp.getTerminator());
- Block *loopBody = forallOp.getBody();
- Block *newLoopBody = newforallOp.getBody();
- rewriter.mergeBlocks(
- loopBody, newLoopBody,
- newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+ // 13. Need to erase the old scf.for and the cloned consumer op.
+ rewriter.eraseOp(forOp);
+ rewriter.eraseOp(clonedConsumerOp);
+
+ return scf::SCFFuseConsumerOfSliceResult{
+ consumerOp, tileAndFuseResult->tiledOps[0], {}};
+ }
+
+ /// Fetch the first untiled consumer of a scf.forall's result which is yielded
+ /// by a tensor.parallel_insert_slice.
+ static OpOperand *getUntiledConsumerFromSlice(
+ tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // Step 1. Fetch the corresponding output
+ Value sliceDest = candidateSliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(sliceDest);
+ Operation *containingOp = iterArg.getOwner()->getParentOp();
+ // Step 2. Check that the containing op is scf.forall.
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ return nullptr;
+ }
+ Value resultingValue =
+ forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ // Step 3. Check resulting value of scf.forall has exactly one use.
+ Value::use_range uses = resultingValue.getUses();
+ if (!llvm::hasSingleElement(uses)) {
+ return nullptr;
+ }
- // 5.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
- rewriter.setInsertionPointAfter(candidateSliceOp);
- auto newForOpBlockArgsForConsumerDest =
- newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
- auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
- rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
-
- // 5.b. Replace all uses of the scf.forall's result use in the consumer with
- // the source of the cloned tensor.parallel_insert_slice.
- OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
- rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
- operandToReplace.set(candidateSliceOp.getSource());
- });
-
- // 6. Perform tiling of the cloned consumer.
- rewriter.setInsertionPoint(newforallOp.getTerminator());
- FailureOr<TilingResult> tileAndFuseResult =
- tensor::replaceInsertSliceWithTiledConsumer(
- rewriter,
- cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
- clonedConsumerOp->getOpOperand(operandNumber));
- if (failed(tileAndFuseResult)) {
- return rewriter.notifyMatchFailure(clonedConsumerOp,
- "failed to tile consumer op: ");
- }
+ // Step 4. Get uses.
+ OpOperand &operand = (*resultingValue.getUses().begin());
+ Operation *consumerOp = operand.getOwner();
+ // TODO: We have to init result of consumer before scf.forall, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other ops, such as ops that have InferTypeOpInterface.
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp)) {
+ return nullptr;
+ }
+ if (containingOp->getBlock() != consumerOp->getBlock()) {
+ return nullptr;
+ }
+ return &operand;
+ }
+
+ /// Implementation of fusing consumer of a single slice by computing the
+ /// slice of the consumer in-place for scf.forall.
+ static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+ tileAndFuseConsumerOfSliceSCFForall(
+ RewriterBase & rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the dest.
+ OpOperand *consumerOpOperand =
+ getUntiledConsumerFromSlice(candidateSliceOp);
+ if (!consumerOpOperand) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = consumerOpOperand->getOwner();
+ unsigned operandNumber = consumerOpOperand->getOperandNumber();
+ unsigned resultNumber =
+ cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ // Using candidateSliceOp->getParentOp() because we have the following case
+ // :- scf.forall.in_parallel {
+ // tensor.parallel_insert_slice ...
+ // }
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+ Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+ auto forallOp = cast<scf::ForallOp>(containingOp);
+
+ // 2. Check consumer is not using scf.forall's output as init.
+ auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer op taking the result of "
+ "scf.forall as init is not supported");
+ }
- // 7. Extract offset/sizes/strides required to create the
- // tensor.parallel_insert_slice for each result of the consumer.
- SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
- // 8. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
- // 9. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp, "can't get iter domain position from input position");
- }
+ Location loc = forallOp.getLoc();
+ // 3. Create new scf.forall op.
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.append(dpsInits);
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+ // 4. Move the loop body to the new op.
+ rewriter.eraseOp(newforallOp.getTerminator());
+ Block *loopBody = forallOp.getBody();
+ Block *newLoopBody = newforallOp.getBody();
+ rewriter.mergeBlocks(
+ loopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+ // 5.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ auto newForOpBlockArgsForConsumerDest =
+ newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
+ auto clonedConsumerOp =
+ cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // 5.b. Replace all uses of the scf.forall's result use in the consumer with
+ // the source of the cloned tensor.parallel_insert_slice.
+ OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ operandToReplace.set(candidateSliceOp.getSource());
+ });
+
+ // 6. Perform tiling of the cloned consumer.
+ rewriter.setInsertionPoint(newforallOp.getTerminator());
+ FailureOr<TilingResult> tileAndFuseResult =
+ tensor::replaceInsertSliceWithTiledConsumer(
+ rewriter,
+ cast<OffsetSizeAndStrideOpInterface>(
+ candidateSliceOp.getOperation()),
+ clonedConsumerOp->getOpOperand(operandNumber));
+ if (failed(tileAndFuseResult)) {
+ return rewriter.notifyMatchFailure(clonedConsumerOp,
+ "failed to tile consumer op: ");
+ }
- // 10. Try to fetch the offset and size for all results of the cloned
- // consumer. This would then be used to form the corresponding
- // tensor.parallel_insert_slice later.
- unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
- for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
- if (failed(clonedConsumerOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
+ // 7. Extract offset/sizes/strides required to create the
+ // tensor.parallel_insert_slice for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+ // 8. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
+ // 9. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
+ if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
return rewriter.notifyMatchFailure(
clonedConsumerOp,
- "can't get result domain position from iter domain position");
+ "can't get iter domain position from input position");
}
- }
- // 11. Fix terminator.
- scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
- rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
- Location firstYieldOpLoc =
- (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
- MutableArrayRef<BlockArgument> bbArgs = newforallOp.getBody()->getArguments();
- for (auto [idx, v] :
- llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
- SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
- rewriter.getIndexAttr(1));
- rewriter.create<tensor::ParallelInsertSliceOp>(
- firstYieldOpLoc, v,
- bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
- resultOffsets[idx], resultSizes[idx], strides);
- }
-
- // 12. Replace the result of scf.forall and consumer op.
- for (auto &&[oldResult, newResult] :
- llvm::zip_first(forallOp.getResults(), newforallOp.getResults())) {
- rewriter.replaceAllUsesWith(oldResult, newResult);
- }
+ // 10. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.parallel_insert_slice later.
+ unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(
+ totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(clonedConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
- for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
- rewriter.replaceAllUsesWith(
- oldValue, newforallOp->getResult(forallOp.getOutputs().size() + index));
- }
+ // 11. Fix terminator.
+ scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+ rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+ Location firstYieldOpLoc =
+ (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+ MutableArrayRef<BlockArgument> bbArgs =
+ newforallOp.getBody()->getArguments();
+ for (auto [idx, v] :
+ llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+ rewriter.getIndexAttr(1));
+ rewriter.create<tensor::ParallelInsertSliceOp>(
+ firstYieldOpLoc, v,
+ bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+ resultOffsets[idx], resultSizes[idx], strides);
+ }
- // 13. Need to erase the old scf.forall and cloned consumer.
- rewriter.eraseOp(forallOp);
- rewriter.eraseOp(clonedConsumerOp);
+ // 12. Replace the result of scf.forall and consumer op.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip_first(forallOp.getResults(), newforallOp.getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
- return scf::SCFFuseConsumerOfSliceResult{
- consumerOp, tileAndFuseResult->tiledOps[0], {}};
-}
+ for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(
+ oldValue,
+ newforallOp->getResult(forallOp.getOutputs().size() + index));
+ }
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
- Operation *candidateSliceOp) {
- if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
- return tileAndFuseConsumerOfSliceSCFFor(rewriter, sliceOp);
- } else if (auto sliceOp =
- dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
- return tileAndFuseConsumerOfSliceSCFForall(rewriter, sliceOp);
- } else {
- return failure();
+ // 13. Need to erase the old scf.forall and cloned consumer.
+ rewriter.eraseOp(forallOp);
+ rewriter.eraseOp(clonedConsumerOp);
+
+ return scf::SCFFuseConsumerOfSliceResult{
+ consumerOp, tileAndFuseResult->tiledOps[0], {}};
+ }
+
+ /// Implementation of fusing consumer of a single slice by computing the
+ /// slice of the consumer in-place.
+ FailureOr<scf::SCFFuseConsumerOfSliceResult>
+ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase & rewriter,
+ Operation * candidateSliceOp) {
+ if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
+ return tileAndFuseConsumerOfSliceSCFFor(rewriter, sliceOp);
+ } else if (auto sliceOp =
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+ return tileAndFuseConsumerOfSliceSCFForall(rewriter, sliceOp);
+ } else {
+ return failure();
+ }
}
-}
-//===----------------------------------------------------------------------===//
-// lowerToLoopsUsingSCFForOp implementation.
-//===----------------------------------------------------------------------===//
+ //===----------------------------------------------------------------------===//
+ // lowerToLoopsUsingSCFForOp implementation.
+ //===----------------------------------------------------------------------===//
-FailureOr<SmallVector<scf::ForOp>>
-mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
- TilingInterface op) {
- // TODO: Handle cases where the op has results if needed.
- if (op->getNumResults() > 0) {
- return rewriter.notifyMatchFailure(
- op, "unable to lower to loops operations with return values");
- }
+ FailureOr<SmallVector<scf::ForOp>> mlir::scf::lowerToLoopsUsingSCFForOp(
+ RewriterBase & rewriter, TilingInterface op) {
+ // TODO: Handle cases where the op has results if needed.
+ if (op->getNumResults() > 0) {
+ return rewriter.notifyMatchFailure(
+ op, "unable to lower to loops operations with return values");
+ }
- SmallVector<Range> domain = op.getIterationDomain(rewriter);
- SmallVector<Value> ivs;
- SmallVector<scf::ForOp> loops;
- Location loc = op.getLoc();
- for (auto loopRange : domain) {
- Value offsetVal =
- getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
- Value sizeVal =
- getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
- Value strideVal =
- getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
- auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
- strideVal, ValueRange{});
- loops.push_back(loop);
- ivs.push_back(loop.getInductionVar());
- rewriter.setInsertionPoint(loop.getBody()->getTerminator());
- }
- if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
- return failure();
+ SmallVector<Range> domain = op.getIterationDomain(rewriter);
+ SmallVector<Value> ivs;
+ SmallVector<scf::ForOp> loops;
+ Location loc = op.getLoc();
+ for (auto loopRange : domain) {
+ Value offsetVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+ Value sizeVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+ Value strideVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
+ auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
+ strideVal, ValueRange{});
+ loops.push_back(loop);
+ ivs.push_back(loop.getInductionVar());
+ rewriter.setInsertionPoint(loop.getBody()->getTerminator());
+ }
+ if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
+ return failure();
+ }
+ return loops;
}
- return loops;
-}
>From f0dc8cce02d69971f21232ceda2b71dd76224d2c Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 7 May 2024 10:19:00 +0000
Subject: [PATCH 04/10] Unify the code for scf.for/forall fuse consumer + fix
crash
---
.../SCF/Transforms/TileUsingInterface.cpp | 734 ++++++++----------
.../tile-and-fuse-consumer.mlir | 116 ++-
2 files changed, 406 insertions(+), 444 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index eac23bfec4fa4..7affc10f3eb17 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1173,14 +1173,138 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
return &operand;
}
-typedef struct helper1struct {
- OpOperand *consumerOpOperand;
-};
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // Step 1. Fetch the corresponding output
+ Value sliceDest = candidateSliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(sliceDest);
+ Operation *containingOp = iterArg.getOwner()->getParentOp();
+ // Step 2. Check that the containing op is scf.forall.
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ return nullptr;
+ }
+ Value resultingValue =
+ forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ // Step 3. Check resulting value of scf.forall has exactly one use.
+ Value::use_range uses = resultingValue.getUses();
+ if (!llvm::hasSingleElement(uses)) {
+ return nullptr;
+ }
+
+ // Step 4. Get uses.
+ OpOperand &operand = (*resultingValue.getUses().begin());
+ Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init the result of the consumer before scf.forall; use
+  // DestinationStyleOpInterface to get the result shape from the init for now.
+  // Add support for other ops, such as ops that have InferTypeOpInterface.
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp)) {
+ return nullptr;
+ }
+ if (containingOp->getBlock() != consumerOp->getBlock()) {
+ return nullptr;
+ }
+ return &operand;
+}
+
+/// This utility currently checks that the loop either:
+/// 1. Yields exactly one result, or
+/// 2. Has the consumer op as its first user, with all other users residing in
+/// the same containing block as the consumer op. Since we clone the loop op
+/// right before the consumer op in order to maintain a valid def-use chain,
+/// this check ensures that no invalid IR is formed as a result of the
+/// cloning.
+static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+ Operation *consumerOp) {
+ // Check if the loop op yields one result.
+ if (loopOp->getNumResults() == 1)
+ return success();
+ // Check if the consumerOp is the first user of the loopOp and if other users
+ // are in the same containing block as that of consumer op's.
+ Block *parentBlock = consumerOp->getBlock();
+ for (Operation *userOp : loopOp->getUsers()) {
+ if (userOp == consumerOp)
+ continue;
+ if (parentBlock != userOp->getBlock() ||
+ !consumerOp->isBeforeInBlock(userOp))
+ return failure();
+ }
+ return success();
+}
+
+static OpOperand *getUntiledConsumerFromSlice(Operation *sliceOp) {
+ if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
+ return getUntiledConsumerFromSlice(insertSlice);
+ } else if (auto parallelInsertSlice =
+ dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
+ return getUntiledConsumerFromSlice(parallelInsertSlice);
+ } else {
+ return nullptr;
+ }
+}
+
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+ TilingResult tilingResult,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+ SmallVector<OpFoldResult> &strides, unsigned initSize) {
+ scf::YieldOp oldTerminatorOp =
+ cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+ SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+ rewriter.setInsertionPointAfter(oldTerminatorOp);
+ MutableArrayRef<BlockArgument> bbArgs = newForOp.getBody()->getArguments();
+ Location loc = newForOp.getLoc();
+ for (auto [idx, v] :
+ llvm::enumerate(tilingResult.tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+ rewriter.getIndexAttr(1));
+ Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, v, bbArgs[1 + initSize + idx], resultOffsets[idx],
+ resultSizes[idx], strides);
+ newYieldOperands.push_back(newInsertSliceOp);
+ }
+ rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+ rewriter.eraseOp(oldTerminatorOp);
+}
+
+static void fixTerminatorSCFInParallel(
+ RewriterBase &rewriter, scf::ForallOp newForallOp,
+ TilingResult tilingResult,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+ SmallVector<OpFoldResult> &strides, unsigned initSize, unsigned rank) {
+ scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
+ rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+ Location firstYieldOpLoc =
+ (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+ MutableArrayRef<BlockArgument> bbArgs = newForallOp.getBody()->getArguments();
+ for (auto [idx, v] :
+ llvm::enumerate(tilingResult.tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+ rewriter.getIndexAttr(1));
+ rewriter.create<tensor::ParallelInsertSliceOp>(
+ firstYieldOpLoc, v, bbArgs[rank + initSize + idx], resultOffsets[idx],
+ resultSizes[idx], strides);
+ }
+}
+
+/// Implementation of fusing the consumer of a single slice by computing the
+/// slice of the consumer in-place for an scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+ Operation *candidateSliceOp) {
+ if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+ candidateSliceOp))
+ return failure();
+
+ bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
-template <typename T>
-static helper1struct helper1(T candidateSliceOp) {
// 1. Get the consumer of scf.for for the result yielded by
- // tensor.insert_slice.
+ // tensor.insert_slice/parallel_insert_slice.
OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
if (!consumerOpOperand) {
return rewriter.notifyMatchFailure(candidateSliceOp,
@@ -1191,443 +1315,217 @@ static helper1struct helper1(T candidateSliceOp) {
unsigned resultNumber =
cast<OpResult>(consumerOpOperand->get()).getResultNumber();
- if (std::is_same<T, tensor::InsertSliceOp>::value) {
- auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(candidateSliceOp);
-
- // 2. Check consumer is not using scf.for's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
- SmallVector<Value> dpsInits =
- llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
- return rewriter.notifyMatchFailure(
- consumerOp,
- "consumer op taking the result of scf.for as init is not supported");
- }
-
- Location loc = forOp.getLoc();
- SmallVector<Value> newOuts(forOp.getInits());
- newOuts.append(dpsInits);
-
- // 3. Create new scf.for op.
- rewriter.setInsertionPoint(consumerOp);
- auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
- forOp.getUpperBound(),
- forOp.getStep(), newOuts);
- // 4. Move the loop body to the new op.
- Block *loopBody = forOp.getBody();
- Block *newLoopBody = newforOp.getBody();
- rewriter.mergeBlocks(
- loopBody, newLoopBody,
- newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
-
- helper1struct ret;
- ret.consumerOpOperand = consumerOpOperand;
- ret.newLoop = newforOp;
- return ret;
- }
-
- /// Implementation of fusing consumer of a single slice by computing the
- /// slice of the consumer in-place for scf.for.
- static FailureOr<scf::SCFFuseConsumerOfSliceResult>
- tileAndFuseConsumerOfSliceSCFFor(RewriterBase & rewriter,
- tensor::InsertSliceOp candidateSliceOp) {
- // 1. Get the consumer of scf.for for the result yielded by
- // tensor.insert_slice.
- OpOperand *consumerOpOperand =
- getUntiledConsumerFromSlice(candidateSliceOp);
- if (!consumerOpOperand) {
- return rewriter.notifyMatchFailure(candidateSliceOp,
- "could not fetch consumer to fuse");
- }
- Operation *consumerOp = consumerOpOperand->getOwner();
- unsigned operandNumber = consumerOpOperand->getOperandNumber();
- unsigned resultNumber =
- cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+ Operation *oldLoopOp = nullptr;
+ SmallVector<Value> newOuts;
+ Block *oldLoopBody = nullptr;
+ unsigned initSize = 0;
+ unsigned rank = 1;
+ if (isInsertSliceOp) {
+ auto forOp = candidateSliceOp->template getParentOfType<scf::ForOp>();
+ SmallVector<Value> forOpOuts(forOp.getInits());
+ oldLoopOp = forOp;
+ newOuts = forOpOuts;
+ oldLoopBody = forOp.getBody();
+ initSize = forOp.getInits().size();
+ } else {
+ auto forallOp = candidateSliceOp->template getParentOfType<scf::ForallOp>();
+ SmallVector<Value> forallOpOuts(forallOp.getOutputs());
+ oldLoopOp = forallOp;
+ newOuts = forallOpOuts;
+ oldLoopBody = forallOp.getBody();
+ initSize = forallOp.getOutputs().size();
+ rank = forallOp.getRank();
+ }
- auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+ if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ return rewriter.notifyMatchFailure(
+ oldLoopOp, "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
+ }
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(candidateSliceOp);
+ OpBuilder::InsertionGuard g(rewriter);
- // 2. Check consumer is not using scf.for's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
- SmallVector<Value> dpsInits =
- llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
- return rewriter.notifyMatchFailure(
- consumerOp,
- "consumer op taking the result of scf.for as init is not supported");
- }
+ // 2. Check consumer is not using scf loop's output as init.
+ auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.for as init is not supported");
+ }
+ newOuts.append(dpsInits);
- Location loc = forOp.getLoc();
- SmallVector<Value> newOuts(forOp.getInits());
- newOuts.append(dpsInits);
+ Location loc = oldLoopOp->getLoc();
- // 3. Create new scf.for op.
- rewriter.setInsertionPoint(consumerOp);
- auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+ // 3. Create new scf loop op.
+ rewriter.setInsertionPoint(consumerOp);
+ Operation *newLoopOp = nullptr;
+ Block *newLoopBody = nullptr;
+ if (isInsertSliceOp) {
+ auto forOp = cast<scf::ForOp>(oldLoopOp);
+ auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(), newOuts);
- // 4. Move the loop body to the new op.
- Block *loopBody = forOp.getBody();
- Block *newLoopBody = newforOp.getBody();
- rewriter.mergeBlocks(
- loopBody, newLoopBody,
- newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
-
- // 5.a. Clone consumer after the cloned tensor.insert_slice op.
- rewriter.setInsertionPointAfter(candidateSliceOp);
- auto newForOpBlockArgsForConsumerDest =
- newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
- auto clonedConsumerOp =
- cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
- rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
-
- // 5.b. Replace all uses of the loop result with the result of the cloned
- // tensor.insert_slice.
- OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
- rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
- operandToReplace.set(candidateSliceOp.getResult());
- });
-
- // 6 - Perform tiling of the cloned consumer.
- rewriter.setInsertionPointAfter(clonedConsumerOp);
- FailureOr<TilingResult> tileAndFuseResult =
- tensor::replaceInsertSliceWithTiledConsumer(
- rewriter,
- cast<OffsetSizeAndStrideOpInterface>(
- candidateSliceOp.getOperation()),
- clonedConsumerOp->getOpOperand(operandNumber));
- if (failed(tileAndFuseResult)) {
- return rewriter.notifyMatchFailure(clonedConsumerOp,
- "failed to tile consumer op: ");
- }
-
- // 7 - Extract offset/sizes/strides required to create the
- // tensor.insert_slice for each result of the consumer.
- SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
- // 8. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
- // 9. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- rewriter.setInsertionPointAfter(clonedConsumerOp);
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp,
- "can't get iter domain position from input position");
- }
-
- // 10. Try to fetch the offset and size for all results of the cloned
- // consumer. This would then be used to form the corresponding
- // tensor.insert_slice later.
- unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(
- totalNumResultsOfConsumer);
- for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
- if (failed(clonedConsumerOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp,
- "can't get result domain position from iter domain position");
- }
- }
-
- // 11. Fix terminator.
- scf::YieldOp oldTerminatorOp =
- cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
- SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
- rewriter.setInsertionPointAfter(oldTerminatorOp);
- MutableArrayRef<BlockArgument> bbArgs = newforOp.getBody()->getArguments();
- for (auto [idx, v] :
- llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
- SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
- rewriter.getIndexAttr(1));
- Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- candidateSliceOp->getLoc(), v,
- bbArgs[1 + forOp.getInits().size() + idx], resultOffsets[idx],
- resultSizes[idx], strides);
- newYieldOperands.push_back(newInsertSliceOp);
- }
- rewriter.create<scf::YieldOp>(loc, newYieldOperands);
- rewriter.eraseOp(oldTerminatorOp);
-
- // 12. Replace the result of scf.for and consumer op.
- for (auto &&[oldResult, newResult] :
- llvm::zip_first(forOp.getResults(), newforOp.getResults())) {
- rewriter.replaceAllUsesWith(oldResult, newResult);
- }
+ newLoopOp = newForOp;
+ newLoopBody = newForOp.getBody();
+ } else {
+ auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+ auto newForallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ newLoopOp = newForallOp;
+ rewriter.eraseOp(newForallOp.getTerminator());
+ newLoopBody = newForallOp.getBody();
+ }
- for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
- rewriter.replaceAllUsesWith(
- oldValue, newforOp->getResult(forOp.getInits().size() + index));
+ // 4. Move the loop body to the new op.
+ unsigned oldNumArguments = oldLoopBody->getNumArguments();
+ rewriter.mergeBlocks(oldLoopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(oldNumArguments));
+
+ // 5.a. Clone consumer after the cloned
+ // tensor.insert_slice/parallel_insert_slice op.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ auto newForOpBlockArgsForConsumerDest =
+ newLoopBody->getArguments().drop_front(oldNumArguments);
+ auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // 5.b. Replace all uses of the loop result with the result of the cloned
+ // tensor.insert_slice/parallel_insert_slice.
+ OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
+ operandToReplace.set(sliceOp.getResult());
+ } else if (auto sliceOp =
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+ operandToReplace.set(sliceOp.getSource());
}
+ });
- // 13. Need to erase the old scf.for and the cloned consumer op.
- rewriter.eraseOp(forOp);
- rewriter.eraseOp(clonedConsumerOp);
-
- return scf::SCFFuseConsumerOfSliceResult{
- consumerOp, tileAndFuseResult->tiledOps[0], {}};
- }
-
- /// Fetch the first untiled consumer of a scf.forall's result which is yielded
- /// by a tensor.parallel_insert_slice.
- static OpOperand *getUntiledConsumerFromSlice(
- tensor::ParallelInsertSliceOp candidateSliceOp) {
- // Step 1. Fetch the corresponding output
- Value sliceDest = candidateSliceOp.getDest();
- auto iterArg = cast<BlockArgument>(sliceDest);
- Operation *containingOp = iterArg.getOwner()->getParentOp();
- // Step 2. Check that the containing op is scf.forall.
- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp) {
- return nullptr;
- }
- Value resultingValue =
- forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
- // Step 3. Check resulting value of scf.forall has exactly one use.
- Value::use_range uses = resultingValue.getUses();
- if (!llvm::hasSingleElement(uses)) {
- return nullptr;
- }
+ // 6 - Perform tiling of the cloned consumer.
+ if (isInsertSliceOp) {
+ rewriter.setInsertionPointAfter(clonedConsumerOp);
+ } else {
+ rewriter.setInsertionPoint(cast<scf::ForallOp>(newLoopOp).getTerminator());
+ }
+ auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
+ FailureOr<TilingResult> tileAndFuseResult =
+ tensor::replaceInsertSliceWithTiledConsumer(
+ rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+ if (failed(tileAndFuseResult)) {
+ return rewriter.notifyMatchFailure(clonedConsumerOp,
+ "failed to tile consumer op: ");
+ }
- // Step 4. Get uses.
- OpOperand &operand = (*resultingValue.getUses().begin());
- Operation *consumerOp = operand.getOwner();
- // TODO: We have to init result of consumer before scf.forall, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp)) {
- return nullptr;
- }
- if (containingOp->getBlock() != consumerOp->getBlock()) {
- return nullptr;
- }
- return &operand;
- }
-
- /// Implementation of fusing consumer of a single slice by computing the
- /// slice of the consumer in-place for scf.forall.
- static FailureOr<scf::SCFFuseConsumerOfSliceResult>
- tileAndFuseConsumerOfSliceSCFForall(
- RewriterBase & rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
- // 1. Get the consumer of the dest.
- OpOperand *consumerOpOperand =
- getUntiledConsumerFromSlice(candidateSliceOp);
- if (!consumerOpOperand) {
- return rewriter.notifyMatchFailure(candidateSliceOp,
- "could not fetch consumer to fuse");
- }
- Operation *consumerOp = consumerOpOperand->getOwner();
- unsigned operandNumber = consumerOpOperand->getOperandNumber();
- unsigned resultNumber =
- cast<OpResult>(consumerOpOperand->get()).getResultNumber();
-
- OpBuilder::InsertionGuard g(rewriter);
- // Using candidateSliceOp->getParentOp() because we have the following case
- // :- scf.forall.in_parallel {
- // tensor.parallel_insert_slice ...
- // }
- rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
-
- Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
- auto forallOp = cast<scf::ForallOp>(containingOp);
-
- // 2. Check consumer is not using scf.forall's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
- SmallVector<Value> dpsInits =
- llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
- return rewriter.notifyMatchFailure(consumerOp,
- "consumer op taking the result of "
- "scf.forall as init is not supported");
- }
+ // 7 - Extract offset/sizes/strides required to create the
+ // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
- Location loc = forallOp.getLoc();
- // 3. Create new scf.forall op.
- SmallVector<Value> newOuts(forallOp.getOutputs());
- newOuts.append(dpsInits);
- rewriter.setInsertionPoint(consumerOp);
- auto newforallOp = rewriter.create<scf::ForallOp>(
- loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ // 8. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
- // 4. Move the loop body to the new op.
- rewriter.eraseOp(newforallOp.getTerminator());
- Block *loopBody = forallOp.getBody();
- Block *newLoopBody = newforallOp.getBody();
- rewriter.mergeBlocks(
- loopBody, newLoopBody,
- newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
-
- // 5.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
- rewriter.setInsertionPointAfter(candidateSliceOp);
- auto newForOpBlockArgsForConsumerDest =
- newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
- auto clonedConsumerOp =
- cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
- rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
-
- // 5.b. Replace all uses of the scf.forall's result use in the consumer with
- // the source of the cloned tensor.parallel_insert_slice.
- OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
- rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
- operandToReplace.set(candidateSliceOp.getSource());
- });
-
- // 6. Perform tiling of the cloned consumer.
- rewriter.setInsertionPoint(newforallOp.getTerminator());
- FailureOr<TilingResult> tileAndFuseResult =
- tensor::replaceInsertSliceWithTiledConsumer(
- rewriter,
- cast<OffsetSizeAndStrideOpInterface>(
- candidateSliceOp.getOperation()),
- clonedConsumerOp->getOpOperand(operandNumber));
- if (failed(tileAndFuseResult)) {
- return rewriter.notifyMatchFailure(clonedConsumerOp,
- "failed to tile consumer op: ");
- }
+ // 9. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- // 7. Extract offset/sizes/strides required to create the
- // tensor.parallel_insert_slice for each result of the consumer.
- SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
- // 8. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
- // 9. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (isInsertSliceOp) {
+ rewriter.setInsertionPointAfter(clonedConsumerOp);
+ } else {
rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
+ }
+ if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp, "can't get iter domain position from input position");
+ }
+
+ // 10. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.insert_slice/parallel_insert_slice later.
+ unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(clonedConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
return rewriter.notifyMatchFailure(
clonedConsumerOp,
- "can't get iter domain position from input position");
+ "can't get result domain position from iter domain position");
}
+ }
- // 10. Try to fetch the offset and size for all results of the cloned
- // consumer. This would then be used to form the corresponding
- // tensor.parallel_insert_slice later.
- unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(
- totalNumResultsOfConsumer);
- for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
- if (failed(clonedConsumerOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp,
- "can't get result domain position from iter domain position");
- }
- }
+ if (isInsertSliceOp) {
+ fixTerminatorSCFYield(rewriter, cast<scf::ForOp>(newLoopOp),
+ *tileAndFuseResult, resultOffsets, resultSizes,
+ strides, initSize);
+ } else {
+ fixTerminatorSCFInParallel(rewriter, cast<scf::ForallOp>(newLoopOp),
+ *tileAndFuseResult, resultOffsets, resultSizes,
+ strides, initSize, rank);
+ }
- // 11. Fix terminator.
- scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
- rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
- Location firstYieldOpLoc =
- (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
- MutableArrayRef<BlockArgument> bbArgs =
- newforallOp.getBody()->getArguments();
- for (auto [idx, v] :
- llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
- SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
- rewriter.getIndexAttr(1));
- rewriter.create<tensor::ParallelInsertSliceOp>(
- firstYieldOpLoc, v,
- bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
- resultOffsets[idx], resultSizes[idx], strides);
- }
+ // 12. Replace the result of scf loop and consumer op with new loop's results.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
- // 12. Replace the result of scf.forall and consumer op.
- for (auto &&[oldResult, newResult] :
- llvm::zip_first(forallOp.getResults(), newforallOp.getResults())) {
- rewriter.replaceAllUsesWith(oldResult, newResult);
- }
+ for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(oldValue,
+ newLoopOp->getResult(initSize + index));
+ }
- for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
- rewriter.replaceAllUsesWith(
- oldValue,
- newforallOp->getResult(forallOp.getOutputs().size() + index));
- }
+ // 13. Need to erase the old scf loop and the cloned consumer op.
+ rewriter.eraseOp(oldLoopOp);
+ rewriter.eraseOp(clonedConsumerOp);
- // 13. Need to erase the old scf.forall and cloned consumer.
- rewriter.eraseOp(forallOp);
- rewriter.eraseOp(clonedConsumerOp);
+ return scf::SCFFuseConsumerOfSliceResult{
+ consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
- return scf::SCFFuseConsumerOfSliceResult{
- consumerOp, tileAndFuseResult->tiledOps[0], {}};
- }
+//===----------------------------------------------------------------------===//
+// lowerToLoopsUsingSCFForOp implementation.
+//===----------------------------------------------------------------------===//
- /// Implementation of fusing consumer of a single slice by computing the
- /// slice of the consumer in-place.
- FailureOr<scf::SCFFuseConsumerOfSliceResult>
- mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase & rewriter,
- Operation * candidateSliceOp) {
- if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
- return tileAndFuseConsumerOfSliceSCFFor(rewriter, sliceOp);
- } else if (auto sliceOp =
- dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
- return tileAndFuseConsumerOfSliceSCFForall(rewriter, sliceOp);
- } else {
- return failure();
- }
+FailureOr<SmallVector<scf::ForOp>>
+mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
+ TilingInterface op) {
+ // TODO: Handle cases where the op has results if needed.
+ if (op->getNumResults() > 0) {
+ return rewriter.notifyMatchFailure(
+ op, "unable to lower to loops operations with return values");
}
- //===----------------------------------------------------------------------===//
- // lowerToLoopsUsingSCFForOp implementation.
- //===----------------------------------------------------------------------===//
-
- FailureOr<SmallVector<scf::ForOp>> mlir::scf::lowerToLoopsUsingSCFForOp(
- RewriterBase & rewriter, TilingInterface op) {
- // TODO: Handle cases where the op has results if needed.
- if (op->getNumResults() > 0) {
- return rewriter.notifyMatchFailure(
- op, "unable to lower to loops operations with return values");
- }
-
- SmallVector<Range> domain = op.getIterationDomain(rewriter);
- SmallVector<Value> ivs;
- SmallVector<scf::ForOp> loops;
- Location loc = op.getLoc();
- for (auto loopRange : domain) {
- Value offsetVal =
- getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
- Value sizeVal =
- getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
- Value strideVal =
- getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
- auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
- strideVal, ValueRange{});
- loops.push_back(loop);
- ivs.push_back(loop.getInductionVar());
- rewriter.setInsertionPoint(loop.getBody()->getTerminator());
- }
- if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
- return failure();
- }
- return loops;
+ SmallVector<Range> domain = op.getIterationDomain(rewriter);
+ SmallVector<Value> ivs;
+ SmallVector<scf::ForOp> loops;
+ Location loc = op.getLoc();
+ for (auto loopRange : domain) {
+ Value offsetVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+ Value sizeVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+ Value strideVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
+ auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
+ strideVal, ValueRange{});
+ loops.push_back(loop);
+ ivs.push_back(loop.getInductionVar());
+ rewriter.setInsertionPoint(loop.getBody()->getTerminator());
+ }
+ if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
+ return failure();
}
+ return loops;
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 3d60e32bfa0cc..f2b64a0d54438 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -190,31 +190,33 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
- %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
- %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
- %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
- tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
}
+ %1 = tensor.empty() : tensor<64x64xf32>
+ %2 = tensor.empty() : tensor<64x64xf32>
+ %3 = tensor.empty() : tensor<64x64xf32>
+ %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
+ %6 = arith.mulf %in, %in_0 : f32
+ %7 = arith.subf %out, %6 : f32
+ %8 = arith.addf %out_1, %in : f32
+ linalg.yield %7, %8 : f32, f32
+ } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+ %5 = tensor.empty() : tensor<2048xf32>
+ %unpack = tensor.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
+ return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
}
- %in_operand_2 = tensor.empty() : tensor<64x64xf32>
- %out_operand_3 = tensor.empty() : tensor<64x64xf32>
- %out_operand_4 = tensor.empty() : tensor<64x64xf32>
- %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64x64xf32>, tensor<64x64xf32>) {
- ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
- %13 = arith.mulf %in, %in_16 : f32
- %14 = arith.subf %out_0, %13 : f32
- %15 = arith.addf %out_1, %in : f32
- linalg.yield %14, %15 : f32, f32
- } -> (tensor<64x64xf32>, tensor<64x64xf32>)
- return %2#1 : tensor<64x64xf32>
- }
}
module attributes {transform.with_named_sequence} {
@@ -232,10 +234,11 @@ module attributes {transform.with_named_sequence} {
// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
-// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
// CHECK-SAME: {
// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
@@ -255,4 +258,65 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: }
// CHECK: }
-// CHECK: return %[[FINAL_RESULT]]#3 :
+// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32>
+// CHECK: return %[[FINAL_RESULT]]#3, %[[UNPACK]] :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+ func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %output = tensor.empty() : tensor<2048xf32>
+ %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
+ return %unpack : tensor<2048xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[UNPACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_MAP]](%[[IV1]])
+// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
+// CHECK: %[[TILED_UNPACK_SRC:.*]] = tensor.extract_slice %[[GENERIC_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[TILED_UNPACK_SRC]]
+// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
>From 5afb645a6232b3ca270c5aebe678d75b09f414d6 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 10 May 2024 07:15:32 +0000
Subject: [PATCH 05/10] [MLIR][Tensor] Add consumer tiling implementation for
tensor.unpack
-- This commit adds tiling implementation for tensor.unpack as a consumer.
Signed-off-by: Abhishek Varma <avarma094 at gmail.com>
---
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 109 ++++++++++++++++++
1 file changed, 109 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 296c5fc7a5c2b..c94db76671ec8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -469,6 +469,115 @@ struct UnPackOpTiling
return failure();
return tilingResult.value();
}
+
+ LogicalResult getIterationDomainTileFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ auto unPackOp = cast<UnPackOp>(op);
+ Location loc = unPackOp.getLoc();
+
+ int64_t numTiles = unPackOp.getInnerDimsPos().size();
+ auto destOffsets = offsets.drop_back(numTiles);
+ auto destSizes = sizes.drop_back(numTiles);
+ // The tiling is applied on interchanged dimensions. We have to undo the
+ // interchange to map sizes and offsets to the original input.
+ int64_t outputRank = unPackOp.getDestRank();
+ SmallVector<OpFoldResult> origOffsets(destOffsets.begin(),
+ destOffsets.end());
+ SmallVector<OpFoldResult> origSizes(destSizes.begin(), destSizes.end());
+ applyPermToRange(origOffsets, origSizes,
+ invertPermutationVector(unPackOp.getOuterDimsPerm()));
+
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ unPackOp.getDimAndTileMapping();
+
+ for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
+ using AV = affine::AffineValueExpr;
+ affine::AffineBuilder ab(b, loc);
+ AffineExpr dim0, dim1, sym;
+ bindDims(b.getContext(), dim0, dim1);
+ bindSymbols(b.getContext(), sym);
+ if (dimAndTileMapping.count(dim)) {
+ // If the data dimension is tiled, the i-th index is the product of
+ // offset_i and tile_i, and the i-th size is the product of sizes_i and
+ // tile_i.
+ auto avOffset = AV(dim0).bind(origOffsets[dim]);
+ auto avSize = AV(dim0).bind(origSizes[dim]);
+ auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+ resultOffsets.push_back(ab.mul(avOffset, avTileSize));
+ resultSizes.push_back(ab.mul(avSize, avTileSize));
+ } else {
+ resultOffsets.push_back(origOffsets[dim]);
+ resultSizes.push_back(origSizes[dim]);
+ }
+ }
+ return success();
+ }
+
+ FailureOr<TilingResult>
+ getTiledImplementationAsConsumer(Operation *op, OpBuilder &b,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) const {
+ auto unPackOp = cast<UnPackOp>(op);
+ Location loc = unPackOp.getLoc();
+
+ // Fetch offset/size for creating the slice of the dest operand of
+ // unpack op.
+ SmallVector<OpFoldResult> outputOffsets, outputSizes;
+ if (failed(getIterationDomainTileFromOperandTile(
+ op, b, 0, offsets, sizes, outputOffsets, outputSizes)))
+ return failure();
+
+ auto oneAttr = b.getI64IntegerAttr(1);
+ int64_t outputRank = unPackOp.getDestRank();
+ SmallVector<OpFoldResult> strides(outputRank, oneAttr);
+
+ SmallVector<Value> tiledOperands;
+ // Create slice of the dest operand.
+ auto extractDestSlice = b.create<ExtractSliceOp>(
+ loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
+ tiledOperands.push_back(extractDestSlice);
+
+ SmallVector<OpFoldResult> inputOffsets, inputSizes;
+ strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
+ // Create slice of the source operand.
+ auto extractSourceSlice = b.create<ExtractSliceOp>(
+ loc, unPackOp.getSource(), offsets, sizes, strides);
+ tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
+ for (auto tile : unPackOp.getInnerTiles())
+ tiledOperands.push_back(tile);
+
+ // Create tiled unpack op.
+ Operation *tiledUnPackOp =
+ b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
+ tiledOperands, op->getAttrs());
+
+ return TilingResult{{tiledUnPackOp},
+ SmallVector<Value>(tiledUnPackOp->getResults())};
+ }
+
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ auto unPackOp = cast<UnPackOp>(op);
+ // tensor.unpack op is fusible (as a consumer) only if inner dims are not
+ // tiled.
+ int64_t numTiles = unPackOp.getInnerDimsPos().size();
+ for (auto iter :
+ llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
+ if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
+ return failure();
+ }
+
+ FailureOr<TilingResult> tilingResult =
+ getTiledImplementationAsConsumer(unPackOp, b, offsets, sizes);
+
+ if (failed(tilingResult))
+ return failure();
+ return tilingResult.value();
+ }
};
} // namespace
>From 955fc67376268025003959fa57e5485082f35375 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 17 May 2024 08:29:52 +0000
Subject: [PATCH 06/10] Address review comments
---
.../SCF/Transforms/TileUsingInterface.h | 5 +-
.../Dialect/Tensor/Transforms/Transforms.h | 5 +-
.../mlir/Interfaces/TilingInterface.td | 8 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 23 ++-
.../SCF/Transforms/TileUsingInterface.cpp | 174 +++++++++---------
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 52 +++---
.../TestTilingInterfaceTransformOps.cpp | 6 +-
8 files changed, 139 insertions(+), 138 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index c744a17ae9191..aaa2dbdbcd947 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -245,8 +245,9 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
/// value but does not delete the slice operation.
struct SCFFuseConsumerOfSliceResult {
- Operation *origConsumer; // Original untiled consumer.
- Operation *tiledAndFusedConsumer; // Tiled and fused consumer op.
+ OpOperand *origConsumerOperand; // Original untiled consumer's operand.
+ OpOperand
+ *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 98447cf62900d..bf6c88f7b77a8 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -32,10 +32,7 @@ FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`. The pattern itself does not
-/// provide a mechanism to control where the application happens. With use of
-/// transform dialect that control is done within the transform dialect. Other
-/// use cases can inherit from this pattern and add necessary controls.
+/// consumer implements the `TilingInterface`.
FailureOr<TilingResult>
replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
OffsetSizeAndStrideOpInterface sliceOp,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 84f7dec2f4003..fbd1021ce67fe 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -89,8 +89,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
"unsigned":$resultNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
- "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
- "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+ "SmallVector<OpFoldResult> &":$resultOffsets,
+ "SmallVector<OpFoldResult> &":$resultSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -98,8 +98,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
>,
InterfaceMethod<
/*desc=*/[{
- Method to return the position of iteration domain tile computed by the
- tiled operation.
+ Method to return the tile of the iteration domain where
+ values from the given tile of the operand are used.
}],
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getIterationDomainTileFromOperandTile",
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 03716eaaa6358..e5f83331baf81 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2454,8 +2454,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
LogicalResult SoftmaxOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
- SmallVectorImpl<OpFoldResult> &resultSizes) {
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
if (resultNumber == 0) {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 71e9c3771dcde..093c8b6347ea7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,9 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
+ /// Utility to fetch the offsets and sizes when applied as per the indexing
+ /// map of the linalg op. This helps in fusing the linalg op as a consumer of
+ /// a given slice op.
void
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
ArrayRef<OpFoldResult> offsets,
@@ -158,8 +161,8 @@ struct LinalgOpTilingInterface
}
}
- /// Return the details of the output tile generated by the tiled
- /// implementation.
+ /// Method to return the tile of the iteration domain where values from the
+ /// given tile of the operand are used.
LogicalResult getIterationDomainTileFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
@@ -174,9 +177,9 @@ struct LinalgOpTilingInterface
AffineMap indexingMap =
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
if (!indexingMap.isProjectedPermutation()) {
- return emitError(op->getLoc(),
- "unhandled get iter domain position when operand is not "
- "accessed using a permuted projection");
+ return op->emitError()
+ << "unhandled get iter domain position when operand is not "
+ "accessed using a permuted projection";
}
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
@@ -190,8 +193,8 @@ struct LinalgOpTilingInterface
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<OpFoldResult> &resultOffsets,
- SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) const {
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
@@ -212,6 +215,8 @@ struct LinalgOpTilingInterface
return success();
}
+ /// Method to generate the tiled implementation of an operation from operand
+ /// tile position.
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
@@ -219,9 +224,7 @@ struct LinalgOpTilingInterface
auto tilingInterfaceOp = cast<TilingInterface>(op);
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
- return emitError(
- op->getLoc(),
- "unable to obtain the iter domain position of the operation.");
+ return failure();
}
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
mappedSizes);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 7affc10f3eb17..8064e6f524327 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1106,10 +1106,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// tileAndFuseConsumerUsingSCF implementation.
//===----------------------------------------------------------------------===//
-/// A utility function that checks whether the passed value has only one user.
-/// In case the defining operation is a tensor.insert_slice, it checks if the
-/// user is scf.yield.
-static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+/// A utility function that checks whether the only use of the result of a
+/// tensor.insert_slice op is in a scf.yield op.
+static LogicalResult
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+ Value result = candidateSliceOp.getResult();
Value::use_range uses = result.getUses();
if (!llvm::hasSingleElement(uses)) {
LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
@@ -1131,83 +1132,70 @@ static LogicalResult checkAssumptionForFusingConsumer(Value result) {
return success();
}
-/// Fetch the first untiled consumer of a scf.for's result which is yielded by
-/// a tensor.insert_slice. This function makes the following assumptions :-
+/// Fetches the OpOperand of the only user (and use) of the value `val` which
+/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
+/// failure otherwise.
+static FailureOr<OpOperand *> getConsumerFromUses(Value val,
+ Block *containingOpBlock) {
+ // Step 1. Check that the value has exactly one use.
+ if (!llvm::hasSingleElement(val.getUses()))
+ return failure();
+ // Step 2. Get uses.
+ OpOperand &operand = (*val.getUses().begin());
+ Operation *consumerOp = operand.getOwner();
+ // TODO: We have to init result of consumer before scf.for, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp))
+ return failure();
+ if (containingOpBlock != consumerOp->getBlock())
+ return failure();
+ return &operand;
+}
+
+/// Fetch the untiled consumer of a scf.for's result which is yielded by a
+/// tensor.insert_slice. This function makes the following assumptions :
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
-static OpOperand *
+static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+ if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
+ return failure();
Value sliceResult = candidateSliceOp.getResult();
- if (failed(checkAssumptionForFusingConsumer(candidateSliceOp.getResult()))) {
- return nullptr;
- }
// Step 1. Fetch the corresponding output.
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
unsigned resultNumber = yieldOpOperand.getOperandNumber();
// Step 2. Check containing op is scf.for.
Operation *containingOp = candidateSliceOp->getParentOp();
auto forOp = dyn_cast<scf::ForOp>(containingOp);
- if (!forOp) {
- return nullptr;
- }
+ if (!forOp)
+ return failure();
Value resultingValue = forOp->getResult(resultNumber);
- // Step 3. Check resulting value of scf.for has exactly one use.
- if (!llvm::hasSingleElement(resultingValue.getUses())) {
- return nullptr;
- }
-
- // Step 4. Get uses.
- OpOperand &operand = (*resultingValue.getUses().begin());
- Operation *consumerOp = operand.getOwner();
- // TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp)) {
- return nullptr;
- }
- if (containingOp->getBlock() != consumerOp->getBlock()) {
- return nullptr;
- }
- return &operand;
+ return getConsumerFromUses(resultingValue, containingOp->getBlock());
}
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
/// by a tensor.parallel_insert_slice.
-static OpOperand *
+static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
// Step 1. Fetch the corresponding output
Value sliceDest = candidateSliceOp.getDest();
- auto iterArg = cast<BlockArgument>(sliceDest);
+ auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+ if (!iterArg)
+ return failure();
Operation *containingOp = iterArg.getOwner()->getParentOp();
+ if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
+ return failure();
// Step 2. Check that the containing op is scf.forall.
auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp) {
- return nullptr;
- }
+ if (!forallOp)
+ return failure();
Value resultingValue =
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
- // Step 3. Check resulting value of scf.forall has exactly one use.
- Value::use_range uses = resultingValue.getUses();
- if (!llvm::hasSingleElement(uses)) {
- return nullptr;
- }
- // Step 4. Get uses.
- OpOperand &operand = (*resultingValue.getUses().begin());
- Operation *consumerOp = operand.getOwner();
- // TODO: We have to init result of consumer before scf.forall, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp)) {
- return nullptr;
- }
- if (containingOp->getBlock() != consumerOp->getBlock()) {
- return nullptr;
- }
- return &operand;
+ return getConsumerFromUses(resultingValue, containingOp->getBlock());
}
/// This utility currently checks whether the loop either :-
@@ -1235,60 +1223,68 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
return success();
}
-static OpOperand *getUntiledConsumerFromSlice(Operation *sliceOp) {
+/// A utility to fetch an untiled consumer of
+/// tensor.insert_slice/tensor.parallel_insert_slice.
+static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
return getUntiledConsumerFromSlice(insertSlice);
} else if (auto parallelInsertSlice =
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
return getUntiledConsumerFromSlice(parallelInsertSlice);
} else {
- return nullptr;
+ return failure();
}
}
-static void
-fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
- TilingResult tilingResult,
- SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
- SmallVector<SmallVector<OpFoldResult>> &resultSizes,
- SmallVector<OpFoldResult> &strides, unsigned initSize) {
+/// After fusing consumer into scf.for we want to modify the scf.yield operation
+/// to reflect the same by returning the values yielded by the tiled consumer.
+static void fixTerminatorSCFYield(
+ RewriterBase &rewriter, scf::ForOp newForOp, TilingResult tilingResult,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+ SmallVector<OpFoldResult> &strides, ArrayRef<BlockArgument> bbArgs) {
scf::YieldOp oldTerminatorOp =
cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
- SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+ unsigned totalOldResults = oldTerminatorOp->getNumResults();
+ unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+ SmallVector<Value> newYieldOperands;
+ newYieldOperands.reserve(totalOldResults + totalTiledResults);
+ for (auto oldResult : oldTerminatorOp.getResults()) {
+ newYieldOperands.push_back(oldResult);
+ }
rewriter.setInsertionPointAfter(oldTerminatorOp);
- MutableArrayRef<BlockArgument> bbArgs = newForOp.getBody()->getArguments();
Location loc = newForOp.getLoc();
for (auto [idx, v] :
llvm::enumerate(tilingResult.tiledOps[0]->getResults())) {
SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
rewriter.getIndexAttr(1));
Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, v, bbArgs[1 + initSize + idx], resultOffsets[idx],
- resultSizes[idx], strides);
+ loc, v, bbArgs[idx], resultOffsets[idx], resultSizes[idx], strides);
newYieldOperands.push_back(newInsertSliceOp);
}
rewriter.create<scf::YieldOp>(loc, newYieldOperands);
rewriter.eraseOp(oldTerminatorOp);
}
+/// After fusing consumer into scf.forall we want to yield each of the resulting
+/// values by the tiled consumer within scf.forall.in_parallel region.
static void fixTerminatorSCFInParallel(
RewriterBase &rewriter, scf::ForallOp newForallOp,
TilingResult tilingResult,
SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
SmallVector<SmallVector<OpFoldResult>> &resultSizes,
- SmallVector<OpFoldResult> &strides, unsigned initSize, unsigned rank) {
+ SmallVector<OpFoldResult> &strides, ArrayRef<BlockArgument> bbArgs) {
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
Location firstYieldOpLoc =
(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
- MutableArrayRef<BlockArgument> bbArgs = newForallOp.getBody()->getArguments();
for (auto [idx, v] :
llvm::enumerate(tilingResult.tiledOps[0]->getResults())) {
SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
- firstYieldOpLoc, v, bbArgs[rank + initSize + idx], resultOffsets[idx],
- resultSizes[idx], strides);
+ firstYieldOpLoc, v, bbArgs[idx], resultOffsets[idx], resultSizes[idx],
+ strides);
}
}
@@ -1305,15 +1301,22 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
- OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
- if (!consumerOpOperand) {
+ FailureOr<OpOperand *> maybeConsumerOpOperand =
+ getUntiledConsumerFromSlice(candidateSliceOp);
+ if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
}
+ OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
Operation *consumerOp = consumerOpOperand->getOwner();
unsigned operandNumber = consumerOpOperand->getOperandNumber();
- unsigned resultNumber =
- cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+ unsigned resultNumber = 0;
+ if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
+ resultNumber = producerResult.getResultNumber();
+ } else {
+ return rewriter.notifyMatchFailure(
+ consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+ }
Operation *oldLoopOp = nullptr;
SmallVector<Value> newOuts;
@@ -1466,13 +1469,16 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
}
if (isInsertSliceOp) {
- fixTerminatorSCFYield(rewriter, cast<scf::ForOp>(newLoopOp),
- *tileAndFuseResult, resultOffsets, resultSizes,
- strides, initSize);
+ auto newForOp = cast<scf::ForOp>(newLoopOp);
+ fixTerminatorSCFYield(
+ rewriter, newForOp, *tileAndFuseResult, resultOffsets, resultSizes,
+ strides, newForOp.getBody()->getArguments().drop_front(1 + initSize));
} else {
- fixTerminatorSCFInParallel(rewriter, cast<scf::ForallOp>(newLoopOp),
- *tileAndFuseResult, resultOffsets, resultSizes,
- strides, initSize, rank);
+ auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ fixTerminatorSCFInParallel(
+ rewriter, newForallOp, *tileAndFuseResult, resultOffsets, resultSizes,
+ strides,
+ newForallOp.getBody()->getArguments().drop_front(rank + initSize));
}
// 12. Replace the result of scf loop and consumer op with new loop's results.
@@ -1491,7 +1497,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
rewriter.eraseOp(clonedConsumerOp);
return scf::SCFFuseConsumerOfSliceResult{
- consumerOp, tileAndFuseResult->tiledOps[0], {}};
+ consumerOpOperand,
+ &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
+ tileAndFuseResult->tiledOps};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index c94db76671ec8..3518da839f9d2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<OpFoldResult> &resultOffsets,
- SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) const {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<OpFoldResult> &resultOffsets,
- SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) const {
// The iteration domain is over outer dimensions of packed layout. In this
// context, the outer dimensions of `resultOffsets` are `offsets`. The
// inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<OpFoldResult> &resultOffsets,
- SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) const {
resultOffsets = llvm::to_vector(offsets);
resultSizes = llvm::to_vector(sizes);
return success();
@@ -470,6 +470,8 @@ struct UnPackOpTiling
return tilingResult.value();
}
+ /// Method to return the position of iteration domain tile computed by the
+ /// tiled operation.
LogicalResult getIterationDomainTileFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
@@ -516,11 +518,20 @@ struct UnPackOpTiling
return success();
}
- FailureOr<TilingResult>
- getTiledImplementationAsConsumer(Operation *op, OpBuilder &b,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes) const {
+ /// Method to return the tiled implementation of tensor.unpack as a consumer.
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
auto unPackOp = cast<UnPackOp>(op);
+ // tensor.unpack op is fusible (as a consumer) only if inner dims are not
+ // tiled.
+ int64_t numTiles = unPackOp.getInnerDimsPos().size();
+ for (auto iter :
+ llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
+ if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
+ return failure();
+ }
+
Location loc = unPackOp.getLoc();
// Fetch offset/size for creating the slice of the dest operand of
@@ -557,27 +568,6 @@ struct UnPackOpTiling
return TilingResult{{tiledUnPackOp},
SmallVector<Value>(tiledUnPackOp->getResults())};
}
-
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
- auto unPackOp = cast<UnPackOp>(op);
- // tensor.unpack op is fusible (as a consumer) only if inner dims are not
- // tiled.
- int64_t numTiles = unPackOp.getInnerDimsPos().size();
- for (auto iter :
- llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
- if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
- return failure();
- }
-
- FailureOr<TilingResult> tilingResult =
- getTiledImplementationAsConsumer(unPackOp, b, offsets, sizes);
-
- if (failed(tilingResult))
- return failure();
- return tilingResult.value();
- }
};
} // namespace
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 181d8ebc68f9e..833fb3cc65b81 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -183,8 +183,10 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
return failure();
// Report back the relevant handles to the transform op.
- originalConsumerOps.push_back(fuseConsumerResults->origConsumer);
- fusedConsumerOps.push_back(fuseConsumerResults->tiledAndFusedConsumer);
+ originalConsumerOps.push_back(
+ fuseConsumerResults->origConsumerOperand->getOwner());
+ fusedConsumerOps.push_back(
+ fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
}
transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
>From ac0c1f5ba7a859ce9bc2cf7f24b71d995460e31c Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 20 May 2024 07:42:43 +0000
Subject: [PATCH 07/10] getTiledImplementationFromOperandTile as static method
---
mlir/include/mlir/Interfaces/TilingInterface.td | 3 ++-
.../lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 4 ++--
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp | 8 ++++----
.../Transforms/SwapExtractSliceWithProducerPatterns.cpp | 4 ++--
4 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index fbd1021ce67fe..df93b98c278fa 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -150,7 +150,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
- InterfaceMethod<
+ StaticInterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation from
operand tile position.
@@ -177,6 +177,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
/*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementationFromOperandTile",
/*args=*/(ins
+ "Operation*":$op,
"OpBuilder &":$b,
"unsigned":$operandNumber,
"ArrayRef<OpFoldResult>":$offsets,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 093c8b6347ea7..9be2946cdb57a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -217,9 +217,9 @@ struct LinalgOpTilingInterface
/// Method to generate the tiled implementation of an operation from operand
/// tile position.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ static FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
auto tilingInterfaceOp = cast<TilingInterface>(op);
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 3518da839f9d2..458fbb30fa70f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -519,9 +519,9 @@ struct UnPackOpTiling
}
/// Method to return the tiled implementation of tensor.unpack as a consumer.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ static FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) {
auto unPackOp = cast<UnPackOp>(op);
// tensor.unpack op is fusible (as a consumer) only if inner dims are not
// tiled.
@@ -537,8 +537,8 @@ struct UnPackOpTiling
// Fetch offset/size for creating the slice of the dest operand of
// unpack op.
SmallVector<OpFoldResult> outputOffsets, outputSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, 0, offsets, sizes, outputOffsets, outputSizes)))
+ if (failed(cast<TilingInterface>(op).getIterationDomainTileFromOperandTile(
+ b, 0, offsets, sizes, outputOffsets, outputSizes)))
return failure();
auto oneAttr = b.getI64IntegerAttr(1);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 858adfc436164..b1aff5618d604 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -56,8 +56,8 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
FailureOr<TilingResult> tiledResult =
consumerOp.getTiledImplementationFromOperandTile(
- builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes());
+ consumerOp, builder, consumer.getOperandNumber(),
+ sliceOp.getMixedOffsets(), sliceOp.getMixedSizes());
if (failed(tiledResult))
return failure();
>From 41123a2871c752cdbbdece54cd7fb11a9a95c768 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 20 May 2024 07:50:07 +0000
Subject: [PATCH 08/10] Address review comments v2
---
.../SCF/Transforms/TileUsingInterface.cpp | 77 ++++++++++---------
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 3 +-
.../tile-and-fuse-consumer.mlir | 15 ++--
3 files changed, 48 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 8064e6f524327..07248c65e85aa 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1238,11 +1238,12 @@ static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
/// After fusing consumer into scf.for we want to modify the scf.yield operation
/// to reflect the same by returning the values yielded by the tiled consumer.
-static void fixTerminatorSCFYield(
- RewriterBase &rewriter, scf::ForOp newForOp, TilingResult tilingResult,
- SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
- SmallVector<SmallVector<OpFoldResult>> &resultSizes,
- SmallVector<OpFoldResult> &strides, ArrayRef<BlockArgument> bbArgs) {
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+ TilingResult &tilingResult,
+ ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+ ArrayRef<BlockArgument> bbArgs) {
scf::YieldOp oldTerminatorOp =
cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
unsigned totalOldResults = oldTerminatorOp->getNumResults();
@@ -1254,12 +1255,13 @@ static void fixTerminatorSCFYield(
}
rewriter.setInsertionPointAfter(oldTerminatorOp);
Location loc = newForOp.getLoc();
- for (auto [idx, v] :
- llvm::enumerate(tilingResult.tiledOps[0]->getResults())) {
- SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+ for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+ llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
+ resultOffsets, resultSizes)) {
+ SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, v, bbArgs[idx], resultOffsets[idx], resultSizes[idx], strides);
+ loc, tiledResult, bbArg, resultOffset, resultSize, strides);
newYieldOperands.push_back(newInsertSliceOp);
}
rewriter.create<scf::YieldOp>(loc, newYieldOperands);
@@ -1268,23 +1270,22 @@ static void fixTerminatorSCFYield(
/// After fusing consumer into scf.forall we want to yield each of the resulting
/// values by the tiled consumer within scf.forall.in_parallel region.
-static void fixTerminatorSCFInParallel(
- RewriterBase &rewriter, scf::ForallOp newForallOp,
- TilingResult tilingResult,
- SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
- SmallVector<SmallVector<OpFoldResult>> &resultSizes,
- SmallVector<OpFoldResult> &strides, ArrayRef<BlockArgument> bbArgs) {
+static void
+fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
+ SmallVector<Value> tiledResults,
+ ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+ ArrayRef<BlockArgument> bbArgs) {
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
Location firstYieldOpLoc =
(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
- for (auto [idx, v] :
- llvm::enumerate(tilingResult.tiledOps[0]->getResults())) {
- SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+ for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+ llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+ SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
- firstYieldOpLoc, v, bbArgs[idx], resultOffsets[idx], resultSizes[idx],
- strides);
+ firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
}
}
@@ -1324,17 +1325,15 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
unsigned initSize = 0;
unsigned rank = 1;
if (isInsertSliceOp) {
- auto forOp = candidateSliceOp->template getParentOfType<scf::ForOp>();
- SmallVector<Value> forOpOuts(forOp.getInits());
+ auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
oldLoopOp = forOp;
- newOuts = forOpOuts;
+ llvm::append_range(newOuts, forOp.getInits());
oldLoopBody = forOp.getBody();
initSize = forOp.getInits().size();
} else {
- auto forallOp = candidateSliceOp->template getParentOfType<scf::ForallOp>();
- SmallVector<Value> forallOpOuts(forallOp.getOutputs());
+ auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
oldLoopOp = forallOp;
- newOuts = forallOpOuts;
+ llvm::append_range(newOuts, forallOp.getOutputs());
oldLoopBody = forallOp.getBody();
initSize = forallOp.getOutputs().size();
rank = forallOp.getRank();
@@ -1407,7 +1406,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
}
});
- // 6 - Perform tiling of the cloned consumer.
+ // 6 - Perform tiling of the cloned consumer and replace the OpOperand that's
+ // already tiled.
if (isInsertSliceOp) {
rewriter.setInsertionPointAfter(clonedConsumerOp);
} else {
@@ -1418,9 +1418,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
tensor::replaceInsertSliceWithTiledConsumer(
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
if (failed(tileAndFuseResult)) {
- return rewriter.notifyMatchFailure(clonedConsumerOp,
- "failed to tile consumer op: ");
+ return failure();
}
+ tileAndFuseResult->tiledOps[0]
+ ->getOpOperand(operandNumber)
+ .set(candidateSliceOp->getOperand(0));
// 7 - Extract offset/sizes/strides required to create the
// tensor.insert_slice/parallel_insert_slice for each result of the consumer.
@@ -1468,16 +1470,18 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
}
}
+ auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
+ auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
if (isInsertSliceOp) {
auto newForOp = cast<scf::ForOp>(newLoopOp);
fixTerminatorSCFYield(
- rewriter, newForOp, *tileAndFuseResult, resultOffsets, resultSizes,
- strides, newForOp.getBody()->getArguments().drop_front(1 + initSize));
+ rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
+ newForOp.getBody()->getArguments().drop_front(1 + initSize));
} else {
auto newForallOp = cast<scf::ForallOp>(newLoopOp);
fixTerminatorSCFInParallel(
- rewriter, newForallOp, *tileAndFuseResult, resultOffsets, resultSizes,
- strides,
+ rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
+ arrayRefOffsets, arrayRefSizes,
newForallOp.getBody()->getArguments().drop_front(rank + initSize));
}
@@ -1487,9 +1491,10 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
rewriter.replaceAllUsesWith(oldResult, newResult);
}
- for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
- rewriter.replaceAllUsesWith(oldValue,
- newLoopOp->getResult(initSize + index));
+ for (auto &&[oldResult, newResult] :
+ llvm::zip(consumerOp->getResults(),
+ newLoopOp->getResults().drop_front(initSize))) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
}
// 13. Need to erase the old scf loop and the cloned consumer op.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 458fbb30fa70f..51350061d23ee 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -538,7 +538,8 @@ struct UnPackOpTiling
// unpack op.
SmallVector<OpFoldResult> outputOffsets, outputSizes;
if (failed(cast<TilingInterface>(op).getIterationDomainTileFromOperandTile(
- b, 0, offsets, sizes, outputOffsets, outputSizes)))
+ b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
+ outputSizes)))
return failure();
auto oneAttr = b.getI64IntegerAttr(1);
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index f2b64a0d54438..400b558e37fcd 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -46,11 +46,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MAT_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
-// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
-// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT]] :
// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
@@ -104,11 +103,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[MAT_OUT:.*]] = linalg.matmul
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
-// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
-// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT]] :
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
@@ -173,12 +171,11 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MAT_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
-// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
-// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
// CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
// CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
@@ -244,12 +241,11 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[MAT_OUT:.*]] = linalg.matmul
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
-// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
-// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
@@ -310,8 +306,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
// CHECK: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_MAP]](%[[IV1]])
// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
-// CHECK: %[[TILED_UNPACK_SRC:.*]] = tensor.extract_slice %[[GENERIC_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
-// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[TILED_UNPACK_SRC]]
+// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
// CHECK: scf.forall.in_parallel {
>From 86b83cf3d78b861c6ce8c68ed746529ba062f082 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 21 May 2024 06:52:57 +0000
Subject: [PATCH 09/10] getTiledImplementationFromOperandTile with default
implementation
---
mlir/include/mlir/Interfaces/TilingInterface.td | 11 ++++++++---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 15 ---------------
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 8 ++++----
.../SwapExtractSliceWithProducerPatterns.cpp | 4 ++--
4 files changed, 14 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index df93b98c278fa..cece4ec3ff05f 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -150,7 +150,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
- StaticInterfaceMethod<
+ InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation from
operand tile position.
@@ -177,14 +177,19 @@ def TilingInterface : OpInterface<"TilingInterface"> {
/*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementationFromOperandTile",
/*args=*/(ins
- "Operation*":$op,
"OpBuilder &":$b,
"unsigned":$operandNumber,
"ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return failure();
+ ::llvm::SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ auto tilingInterfaceOp = cast<::mlir::TilingInterface>($_op.getOperation());
+ if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+ return failure();
+ }
+ return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
}]
>,
InterfaceMethod<
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 9be2946cdb57a..7bef868d05ce5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -215,21 +215,6 @@ struct LinalgOpTilingInterface
return success();
}
- /// Method to generate the tiled implementation of an operation from operand
- /// tile position.
- static FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) {
- SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
- auto tilingInterfaceOp = cast<TilingInterface>(op);
- if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
- b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
- return failure();
- }
- return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
- mappedSizes);
- }
-
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 51350061d23ee..9b2a97eb2b006 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -519,9 +519,9 @@ struct UnPackOpTiling
}
/// Method to return the tiled implementation of tensor.unpack as a consumer.
- static FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) {
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
auto unPackOp = cast<UnPackOp>(op);
// tensor.unpack op is fusible (as a consumer) only if inner dims are not
// tiled.
@@ -537,8 +537,8 @@ struct UnPackOpTiling
// Fetch offset/size for creating the slice of the dest operand of
// unpack op.
SmallVector<OpFoldResult> outputOffsets, outputSizes;
- if (failed(cast<TilingInterface>(op).getIterationDomainTileFromOperandTile(
- b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
+ if (failed(getIterationDomainTileFromOperandTile(
+ op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
outputSizes)))
return failure();
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index b1aff5618d604..858adfc436164 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -56,8 +56,8 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
FailureOr<TilingResult> tiledResult =
consumerOp.getTiledImplementationFromOperandTile(
- consumerOp, builder, consumer.getOperandNumber(),
- sliceOp.getMixedOffsets(), sliceOp.getMixedSizes());
+ builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes());
if (failed(tiledResult))
return failure();
>From 7e9f0b597e8fd378ba7c9fc572ba7b62070eda79 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 22 May 2024 08:10:15 +0000
Subject: [PATCH 10/10] Better algo by Mahesh
---
.../SCF/Transforms/TileUsingInterface.cpp | 66 ++++++++++---------
1 file changed, 34 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 07248c65e85aa..f043e0f59feb0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1386,51 +1386,59 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
rewriter.mergeBlocks(oldLoopBody, newLoopBody,
newLoopBody->getArguments().take_front(oldNumArguments));
- // 5.a. Clone consumer after the cloned
- // tensor.insert_slice/parallel_insert_slice op.
- rewriter.setInsertionPointAfter(candidateSliceOp);
+ // 5. Set insertion point before terminator op of the loop and create a new
+ // tensor.insert_slice. In the scf.for case this is a clone of the
+ // candidateSliceOp whereas in the scf.forall case this is created from the
+ // operands of tensor.parallel_insert_slice.
+ tensor::InsertSliceOp clonedInsertSliceOp;
+ if (auto sliceOp =
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+ auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ rewriter.setInsertionPoint(newForallOp.getTerminator());
+ clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+ } else {
+ auto newForOp = cast<scf::ForOp>(newLoopOp);
+ rewriter.setInsertionPoint(newForOp.getBody()->getTerminator());
+ clonedInsertSliceOp =
+ cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+ }
+
+ // 6.a. Clone consumer op.
auto newForOpBlockArgsForConsumerDest =
newLoopBody->getArguments().drop_front(oldNumArguments);
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
- // 5.b. Replace all uses of the loop result with the result of the cloned
- // tensor.insert_slice/parallel_insert_slice.
+ // 6.b. Replace all uses of the loop result with the result of the cloned
+ // tensor.insert_slice.
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
- if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
- operandToReplace.set(sliceOp.getResult());
- } else if (auto sliceOp =
- dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
- operandToReplace.set(sliceOp.getSource());
- }
+ operandToReplace.set(clonedInsertSliceOp.getResult());
});
- // 6 - Perform tiling of the cloned consumer and replace the OpOperand that's
- // already tiled.
- if (isInsertSliceOp) {
- rewriter.setInsertionPointAfter(clonedConsumerOp);
- } else {
- rewriter.setInsertionPoint(cast<scf::ForallOp>(newLoopOp).getTerminator());
- }
- auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
+ // 7 - Perform tiling of the cloned consumer and replace the operand at
+ // `operandNumber` with the source of the cloned tensor.insert_slice op.
+ auto ossSliceOp =
+ cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
if (failed(tileAndFuseResult)) {
return failure();
}
- tileAndFuseResult->tiledOps[0]
- ->getOpOperand(operandNumber)
- .set(candidateSliceOp->getOperand(0));
+ rewriter.replaceAllUsesWith(
+ tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
+ clonedInsertSliceOp.getSource());
- // 7 - Extract offset/sizes/strides required to create the
+ // 8 - Extract offset/sizes/strides required to create the
// tensor.insert_slice/parallel_insert_slice for each result of the consumer.
SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
- // 8. Check all insert stride is 1.
+  // 9. Check that all insert strides are 1.
if (llvm::any_of(strides, [](OpFoldResult stride) {
return !isConstantIntValue(stride, 1);
})) {
@@ -1438,14 +1446,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
candidateSliceOp, "containingOp's result yield with stride");
}
- // 9. Try to get iter domain position from input position.
+ // 10. Try to get iter domain position from input position.
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-
- if (isInsertSliceOp) {
- rewriter.setInsertionPointAfter(clonedConsumerOp);
- } else {
- rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
- }
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
iterDomainSizes))) {
@@ -1453,7 +1455,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
clonedConsumerOp, "can't get iter domain position from input position");
}
- // 10. Try to fetch the offset and size for all results of the cloned
+ // 11. Try to fetch the offset and size for all results of the cloned
// consumer. This would then be used to form the corresponding
// tensor.insert_slice/parallel_insert_slice later.
unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
More information about the Mlir-commits
mailing list