[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Abhishek Varma
llvmlistbot at llvm.org
Fri Apr 19 04:31:18 PDT 2024
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/88712
>From 045db97bf5ecd34b273328f67a91375ffff9cfa0 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 10 Apr 2024 10:41:46 +0000
Subject: [PATCH 1/2] [MLIR][SCF] Add an API to fuse consumer to a producer
within scf loop
-- This commit adds an API to fuse a consumer into the scf.for/scf.forall
loop that contains its producer.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../SCF/Transforms/TileUsingInterface.h | 13 +
.../Dialect/Tensor/Transforms/Transforms.h | 13 +-
.../mlir/Interfaces/TilingInterface.td | 55 +++
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 96 +++-
.../SCF/Transforms/TileUsingInterface.cpp | 424 ++++++++++++++++++
.../SwapExtractSliceWithProducerPatterns.cpp | 23 +
.../TilingInterface/fuse-consumer.mlir | 119 +++++
.../TestTilingInterfaceTransformOps.cpp | 52 +++
.../TestTilingInterfaceTransformOps.td | 19 +
9 files changed, 792 insertions(+), 22 deletions(-)
create mode 100644 mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be28..b51947c6dadef4 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -14,6 +14,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include <deque>
@@ -239,6 +240,18 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);
+/// Result of fusing a consumer into the loop that contains a candidate slice
+/// operation; returned by `tileAndFuseConsumerOfSlice`.
+struct SCFFuseConsumerOfSliceResult {
+ Operation *origConsumer; // Original (untiled) consumer operation.
+ Value tiledAndFusedConsumer; // Value produced by the tiled and fused consumer.
+ // Tiled operations created during fusion.
+ // NOTE(review): the scf.for/scf.forall implementations in
+ // TileUsingInterface.cpp currently return this list empty - confirm intent.
+ SmallVector<Operation *> tiledOps;
+};
+/// Fuse the consumer of the loop result that `candidateSliceOp` contributes to
+/// by computing the required slice of the consumer in-place, inside the loop.
+/// `candidateSliceOp` must be a `tensor.insert_slice` (scf.for case) or a
+/// `tensor.parallel_insert_slice` (scf.forall case); any other operation
+/// yields failure.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
FailureOr<SmallVector<scf::ForOp>>
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e8a09c4741043b..98447cf62900d5 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
@@ -22,7 +23,7 @@ namespace tensor {
// Patterns
//===----------------------------------------------------------------------===//
-/// Pattern to swap an `tensor.extract_slice` with its producer when the
+/// Method to swap an `tensor.extract_slice` with its producer when the
/// producer implements the `TilingInterface`. The pattern itself does not
/// provide a mechanism to control where the application happens. With use of
/// transform dialect that control is done within the transform dialect. Other
@@ -30,6 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
+/// Method to generate the tiled implementation of the consumer of an
+/// insert-slice-like operation, when the consumer implements the
+/// `TilingInterface`.
+///
+/// - `sliceOp` supplies the offsets and sizes of the operand tile; all of its
+///   strides must be the constant 1, otherwise failure is returned.
+/// - `consumerOp` is the operand (use) of the consumer through which the
+///   sliced value is consumed; its owner must implement `TilingInterface`,
+///   otherwise failure is returned.
+///
+/// The method itself does not provide a mechanism to control where the
+/// application happens; that control is left to the caller (e.g. the
+/// transform dialect).
+FailureOr<TilingResult>
+replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
+ OffsetSizeAndStrideOpInterface sliceOp,
+ OpOperand &consumerOp);
+
//===----------------------------------------------------------------------===//
// Populate functions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the tile of the iteration domain that corresponds
+          to a given tile of an operand.
+
+          - `operandNumber` is the position of the operand whose tile is
+            given.
+          - `offsets` and `sizes` describe the tile of that operand.
+          - `iterDomainOffsets` and `iterDomainSizes` are populated with the
+            offsets and sizes of the corresponding iteration domain tile.
+
+          The default implementation returns failure.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
InterfaceMethod<
/*desc=*/[{
Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          the position of a tile of one of its operands.
+
+          Generates the IR that computes the tiled implementation of the
+          operation given the tile of operand `operandNumber` that is
+          available. The `offsets` and `sizes` describe that operand tile.
+          This is different from `getTiledImplementation`, which generates
+          the tiled implementation of the operation given a tile of the
+          iteration space. This method enables fusing the operation as a
+          consumer through tile and fuse. The method returns failure if the
+          operation can't be tiled starting from the operand tile. In
+          practical terms this implies it cannot be tiled and fused as a
+          consumer of the producer of that operand.
+
+          - `operandNumber` is the position of the operand whose tile is
+            given.
+          - `offsets` provides the offsets of the tile in the coordinate
+            system of the operand.
+          - `sizes` provides the size of the tile.
+
+          The default implementation returns failure.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..01bf19764b0938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
+  /// Map a tile expressed along the results of `indexingMap` (`offsets` /
+  /// `sizes`, one entry per map result) onto the loops of `op`.
+  ///
+  /// `mappedOffsets` / `mappedSizes` are resized to the number of loops of the
+  /// linalg op. When `indexingMap` is not a full permutation, every loop is
+  /// first seeded with its complete iteration-domain range; the loops that
+  /// appear as results of `indexingMap` are then overwritten with the
+  /// requested tile. Every result of `indexingMap` must be an `AffineDimExpr`.
+  void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
+                              AffineMap indexingMap,
+                              ArrayRef<OpFoldResult> offsets,
+                              ArrayRef<OpFoldResult> sizes,
+                              SmallVector<OpFoldResult> &mappedOffsets,
+                              SmallVector<OpFoldResult> &mappedSizes) const {
+    unsigned loopCount = cast<LinalgOp>(op).getNumLoops();
+    auto tilingOp = cast<TilingInterface>(op);
+    mappedOffsets.resize(loopCount);
+    mappedSizes.resize(loopCount);
+
+    // Loops not covered by the map default to their full extent.
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> domain = tilingOp.getIterationDomain(b);
+      for (size_t loop = 0, numRanges = domain.size(); loop < numRanges;
+           ++loop) {
+        mappedOffsets[loop] = domain[loop].offset;
+        mappedSizes[loop] = domain[loop].size;
+      }
+    }
+
+    // Scatter the given tile onto the loops referenced by the map results.
+    for (unsigned pos = 0, numResults = indexingMap.getNumResults();
+         pos < numResults; ++pos) {
+      unsigned loopDim =
+          cast<AffineDimExpr>(indexingMap.getResult(pos)).getPosition();
+      mappedOffsets[loopDim] = offsets[pos];
+      mappedSizes[loopDim] = sizes[pos];
+    }
+  }
+
+  // Return the tile of the iteration domain (`iterDomainOffsets`,
+  // `iterDomainSizes`) that corresponds to the tile (`offsets`, `sizes`) of
+  // the operand at position `operandNumber`.
+  LogicalResult getIterDomainTilePositionFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &iterDomainOffsets,
+      SmallVector<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    OpOperand &tiledOperand = op->getOpOperand(operandNumber);
+    AffineMap operandMap = linalgOp.getMatchingIndexingMap(&tiledOperand);
+
+    // Only projected permutations are handled for now. This could be relaxed
+    // with a more general approach that can map the offsets and sizes from
+    // the operand to iteration space tiles (filling in full extent for
+    // dimensions not used to access the operand).
+    if (operandMap.isProjectedPermutation()) {
+      getMappedOffsetAndSize(op, b, operandMap, offsets, sizes,
+                             iterDomainOffsets, iterDomainSizes);
+      return success();
+    }
+
+    return op->emitOpError(
+        "unhandled get iter domain position when operand is not "
+        "accessed using a permuted projection");
+  }
+
+
// Return the details of the output tile generated by the tiled
// implementation.
LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
return success();
}
+  /// Generate the tiled implementation of `op` starting from the tile
+  /// (`offsets`, `sizes`) of the operand at position `operandNumber`: the
+  /// operand tile is first translated into an iteration-domain tile, which is
+  /// then handed to `getTiledImplementation`. Emits an error on `op` when the
+  /// translation fails.
+  FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> iterOffsets, iterSizes;
+    auto tilingOp = cast<TilingInterface>(op);
+    LogicalResult mapped =
+        tilingOp.getIterDomainTilePositionFromOperandPosition(
+            b, operandNumber, offsets, sizes, iterOffsets, iterSizes);
+    if (succeeded(mapped))
+      return tilingOp.getTiledImplementation(b, iterOffsets, iterSizes);
+
+    return op->emitOpError(
+        "unable to obtain the iter domain position of the operation.");
+  }
+
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
-
- auto numLoops = linalgOp.getNumLoops();
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
+ mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
FailureOr<TilingResult> tilingResult =
- tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
- iterationTileSizes);
+ tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+ if (failed(tilingResult))
+ return failure();
+
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69df..625586edc1807f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -1100,6 +1101,429 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerOfSlice implementation.
+//===----------------------------------------------------------------------===//
+
+/// A utility function that checks that `result` has exactly one use and that
+/// the owner of this use is an scf.yield. Returns failure (and prints a debug
+/// message) otherwise.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  Value::use_range uses = result.getUses();
+  // `hasSingleElement` also rejects values without any use, so the message
+  // must not claim that there are "too many" uses.
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected exactly one use of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = (*uses.begin());
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp) << "\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Fetch the first untiled consumer of a scf.for's result which is yielded by
+/// a tensor.insert_slice. This function makes the following assumptions :-
+/// 1. tensor.insert_slice has scf.yield as its only user.
+/// 2. The scf.yield is the terminator of the scf.for that directly contains
+///    the tensor.insert_slice.
+/// 3. scf.for's corresponding result has only one use.
+/// Returns the single use (operand of the consumer) of that scf.for result, or
+/// failure if any assumption does not hold.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceResult = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(sliceResult)))
+    return failure();
+
+  // Step 1. Fetch the corresponding output.
+  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+  // Check containing op is "scf::ForOp".
+  auto forOp = dyn_cast_or_null<scf::ForOp>(candidateSliceOp->getParentOp());
+  if (!forOp)
+    return failure();
+  // The operand number of the yield is only a valid result number of `forOp`
+  // if the yield terminates `forOp` itself and not a region nested inside it.
+  if (yieldOpOperand.getOwner()->getParentOp() != forOp.getOperation())
+    return failure();
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  // Check resultingValue has exactly one use.
+  if (!llvm::hasSingleElement(resultingValue.getUses()))
+    return failure();
+
+  // Step 2. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+///
+/// On success the original scf.for is replaced by a new scf.for that carries
+/// the DPS inits of the consumer as additional iter_args, the tiled consumer
+/// is created inside the loop body, its result is inserted into the new
+/// iter_arg and yielded, and both the old loop and the untiled consumer are
+/// erased. Fails (through `notifyMatchFailure`) when:
+///   - no single consumer of the loop result can be found,
+///   - the consumer does not have exactly one result,
+///   - the consumer does not implement `TilingInterface` or
+///     `DestinationStyleOpInterface`,
+///   - the consumer uses the fused loop result as one of its DPS inits,
+///   - `candidateSliceOp` has a stride that is not the constant 1, or
+///   - the tiled implementation of the consumer cannot be generated.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the source.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  // Result number of the scf.for result that feeds the consumer.
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+  // Check that the consumer results in exactly one value.
+  // TODO: Support fusion for consumers yielding more than one result.
+  if (consumerOp->getResults().size() != 1) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "only those consumers returning exactly one result are supported");
+  }
+  // `getUntiledConsumerFromSlice` already verified that the parent of the
+  // candidate slice is a scf.for, so a checked cast is sufficient here.
+  auto forOp = cast<scf::ForOp>(candidateSliceOp->getParentOp());
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer is not a TileableInterface");
+  }
+
+  // TODO: We have to init result of consumer before scf.for, use
+  // DestinationStyleOpInterface to get result shape from init for now.
+  // Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op should have destination style op interface");
+  }
+
+  // Check consumer is not using scf.for's output as init. The result to look
+  // at is the one that is being fused (`resultNumber`), not always result 0.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  // The new loop carries the original inits followed by the consumer's inits.
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  // Move the loop body to the new op.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // Clone the consumer after the insert_slice. Its destination operands are
+  // the block arguments that were appended for the consumer's inits.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest;
+  for (unsigned i = loopBody->getNumArguments(),
+                n = newLoopBody->getArguments().size();
+       i < n; i++) {
+    newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
+  }
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // Replace scf.for result's use in the cloned consumer with insert_slice
+  // result.
+  rewriter.replaceUsesWithIf(forOp.getResult(resultNumber),
+                             candidateSliceOp.getResult(),
+                             [&](OpOperand &operand) {
+                               return operand.getOwner() == clonedConsumerOp;
+                             });
+
+  // Generate the tiled implementation of the consumer of the source.
+  rewriter.setInsertionPoint(candidateSliceOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  // NOTE(review): the IR has already been modified at this point (new loop,
+  // merged body, cloned consumer); confirm that failing here is acceptable
+  // for the rewriter in use.
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(tileableConsumer,
+                                       "failed to tile consumer op: ");
+  }
+
+  // Update the source of the candidateSlice to be the cloned consumer.
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
+  // Skip the induction variable and the original iter_args to reach the
+  // first block argument that was added for the consumer.
+  auto bbArgs = newforOp.getBody()->getArguments();
+  candidateSliceOpOperands[1] = bbArgs[1 + forOp.getInits().size() + 0];
+  tensor::InsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+  rewriter.replaceAllUsesWith(candidateSliceOp, candidateSliceOp.getSource());
+  rewriter.eraseOp(clonedConsumerOp);
+
+  // Fix terminator.
+  auto oldTerminatorOp = cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+
+  // The use of the candidate slice in the yield was redirected to the slice
+  // source above; restore it while building the new yield operands.
+  SmallVector<Value> newYieldOperands;
+  for (Value val : oldTerminatorOp.getResults()) {
+    if (val == candidateSliceOp.getSource()) {
+      newYieldOperands.push_back(candidateSliceOp.getResult());
+    } else {
+      newYieldOperands.push_back(val);
+    }
+  }
+  newYieldOperands.push_back(clonedCandidateSliceOp.getResult());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+
+  // Replace the result of for and consumer op.
+  for (auto result : llvm::enumerate(forOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
+  }
+
+  // Need to erase the old for.
+  rewriter.eraseOp(forOp);
+  rewriter.eraseOp(consumerOp);
+
+  // NOTE(review): `consumerOp` has been handed to `eraseOp` above; callers
+  // should presumably not dereference `origConsumer` - confirm.
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice. This function assumes that :-
+/// 1. the destination of the tensor.parallel_insert_slice is one of the
+///    region out args (shared outputs) of a scf.forall.
+/// 2. the corresponding scf.forall result has only one use.
+/// Returns that single use, or failure if an assumption does not hold.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output.
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  // The destination may be any value; bail out instead of dereferencing a
+  // null block argument.
+  if (!iterArg)
+    return failure();
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  // Step 2. Check assumptions for the containing op.
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp)
+    return failure();
+  unsigned resultNumber = 0;
+  for (BlockArgument val : forallOp.getRegionOutArgs()) {
+    if (val == iterArg)
+      break;
+    resultNumber++;
+  }
+  // `iterArg` is not one of the region out args (e.g. it is an induction
+  // variable): there is no matching result.
+  if (resultNumber >= forallOp->getNumResults())
+    return failure();
+  Value resultingValue = forallOp->getResult(resultNumber);
+  Value::use_range uses = resultingValue.getUses();
+  if (!llvm::hasSingleElement(uses))
+    return failure();
+
+  // Step 3. Get uses.
+  OpOperand &operand = (*uses.begin());
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+///
+/// On success the original scf.forall is replaced by a new scf.forall that
+/// carries the DPS inits of the consumer as additional outputs, the tiled
+/// consumer is created inside the loop body, a new
+/// tensor.parallel_insert_slice writes its result into the new output, and
+/// both the old loop and the untiled consumer are erased. Fails (through
+/// `notifyMatchFailure`) when:
+///   - no single consumer of the loop result can be found,
+///   - the consumer does not have exactly one result,
+///   - the consumer does not implement `TilingInterface` or
+///     `DestinationStyleOpInterface`,
+///   - the consumer uses the fused loop result as one of its DPS inits,
+///   - `candidateSliceOp` has a stride that is not the constant 1, or
+///   - the tiled implementation of the consumer cannot be generated.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the dest.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  // Result number of the scf.forall result that feeds the consumer.
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+  // Check that the consumer results in exactly one value.
+  // TODO: Support fusion for consumers yielding more than one result.
+  if (consumerOp->getResults().size() != 1) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "only those consumers returning exactly one result are supported");
+  }
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer is not a TileableInterface");
+  }
+
+  // The slice lives in the terminator region of the loop, hence the loop is
+  // the parent of the parent of the slice.
+  auto forallOp =
+      cast<scf::ForallOp>(candidateSliceOp->getParentOp()->getParentOp());
+  // TODO: We have to init result of consumer before scf.forall, use
+  // DestinationStyleOpInterface to get result shape from init for now.
+  // Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op should have destination style op interface");
+  }
+
+  // Check consumer doesn't use scf.forall's output as init.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.forall as init is not supported");
+  }
+
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  Location loc = forallOp.getLoc();
+  // Create new scf.forall op. It carries the original outputs followed by
+  // the consumer's inits.
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+  // Move the loop body to the new op. The terminator created by the builder
+  // is dropped first because the merged body brings its own.
+  rewriter.eraseOp(newforallOp.getTerminator());
+  Block *loopBody = forallOp.getBody();
+  Block *newLoopBody = newforallOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // Clone the consumer after the parallel_insert_slice. Its destination
+  // operands are the block arguments appended for the consumer's inits.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest;
+  for (unsigned i = loopBody->getNumArguments(),
+                n = newLoopBody->getArguments().size();
+       i < n; i++) {
+    newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
+  }
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // Replace scf.forall result's use in the consumer with parallel_insert_slice
+  // source.
+  rewriter.replaceAllUsesWith(forallOp.getResult(resultNumber),
+                              candidateSliceOp.getSource());
+
+  // Generate the tiled implementation of the consumer of the source.
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  // NOTE(review): the IR has already been modified at this point (new loop,
+  // merged body, cloned consumer); confirm that failing here is acceptable
+  // for the rewriter in use.
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(tileableConsumer,
+                                       "failed to tile consumer op: ");
+  }
+
+  // Update the source of the candidateSlice to be the cloned consumer.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
+  // Skip the induction variables and the original outputs to reach the first
+  // block argument that was added for the consumer.
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  candidateSliceOpOperands[1] =
+      bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + 0];
+  tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+  LLVM_DEBUG(llvm::dbgs() << "Created a clone of the candidate slice op : "
+                          << clonedCandidateSliceOp << "\n");
+
+  rewriter.eraseOp(clonedConsumerOp);
+
+  // Replace the result of scf.forall and consumer op.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforallOp->getResult(forallOp.getOutputs().size() +
+                               consumerResult.index()));
+  }
+
+  // Need to erase the old scf.forall and consumer.
+  rewriter.eraseOp(forallOp);
+  rewriter.eraseOp(consumerOp);
+
+  // NOTE(review): `consumerOp` has been handed to `eraseOp` above; callers
+  // should presumably not dereference `origConsumer` - confirm.
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};
+}
+
+/// Entry point for fusing the consumer of a single slice in-place. Dispatches
+/// on the kind of `candidateSliceOp`: tensor.insert_slice is handled by the
+/// scf.for implementation, tensor.parallel_insert_slice by the scf.forall
+/// implementation; any other operation results in failure.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp))
+    return tileAndFuseConsumerOfSliceSCFFor(rewriter, insertSliceOp);
+
+  if (auto parallelInsertSliceOp =
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp))
+    return tileAndFuseConsumerOfSliceSCFForall(rewriter,
+                                               parallelInsertSliceOp);
+
+  return failure();
+}
+
//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 40d79c20538172..6da2abd7ae8449 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -40,3 +40,26 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return *tiledResult;
}
+
+/// Generate the tiled implementation of the owner of `consumer` for the
+/// operand tile described by the offsets and sizes of `sliceOp`. Returns
+/// failure when the owner does not implement `TilingInterface`, when any
+/// stride of `sliceOp` is not the constant 1, or when the owner cannot produce
+/// a tiled implementation from the operand position.
+FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
+    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
+    OpOperand &consumer) {
+  auto tileableOwner = dyn_cast<TilingInterface>(consumer.getOwner());
+  if (!tileableOwner)
+    return failure();
+
+  // `TilingInterface` currently only supports strides being 1.
+  auto isUnitStride = [](OpFoldResult stride) {
+    return isConstantIntValue(stride, 1);
+  };
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isUnitStride))
+    return failure();
+
+  unsigned operandPosition = consumer.getOperandNumber();
+  FailureOr<TilingResult> tiledConsumer =
+      tileableOwner.getTiledImplementationFromOperandPosition(
+          builder, operandPosition, sliceOp.getMixedOffsets(),
+          sliceOp.getMixedSizes());
+  if (failed(tiledConsumer))
+    return failure();
+
+  return *tiledConsumer;
+}
diff --git a/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir
new file mode 100644
index 00000000000000..9eabb385303515
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+module {
+ func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32xf32>
+ %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+ scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+ }
+ %in_operand_2 = tensor.empty() : tensor<64xf32>
+ %out_operand_3 = tensor.empty() : tensor<64xf32>
+ %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+ return %2 : tensor<64xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %yield
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %0 = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0)
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT]] :
+// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+ func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ }
+ }
+ %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+ %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+ %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ return %2 : tensor<64x64xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+ : (!transform.any_op)
+ -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %first_slice_op
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
+// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#2 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 335db1a61f476e..dd7203af54ad2a 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -160,6 +160,58 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
: DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerOp
+//===----------------------------------------------------------------------===//
+
+/// Apply the consumer-fusion transformation to every payload op and record
+/// both the original consumer operation and the fused consumer operation.
+template <typename Range>
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+ Range &&payloadOps, TransformResults &transformResults) {
+ SmallVector<Operation *> originalConsumerOps;
+ SmallVector<Operation *> fusedConsumerOps;
+
+ for (Operation *target : payloadOps) {
+ rewriter.setInsertionPoint(target);
+
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+ scf::tileAndFuseConsumerOfSlice(rewriter, target);
+
+ if (failed(fuseConsumerResults))
+ return failure();
+
+ // Report back the relevant handles to the transform op.
+ originalConsumerOps.push_back(fuseConsumerResults->origConsumer);
+ fusedConsumerOps.push_back(
+ fuseConsumerResults->tiledAndFusedConsumer.getDefiningOp());
+ }
+
+ transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
+ transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
+ return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+ LogicalResult result =
+ applyFuseConsumer(rewriter, getOperation(),
+ state.getPayloadOps(getTarget()), transformResults);
+ return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+ : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseConsumerOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTarget(), effects);
+ producesHandle(getConsumer(), effects);
+ producesHandle(getFusedConsumer(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// TestTileUsingForallOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index ef42375e5286d8..d55d746bd6aa90 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}];
}
+def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+    Fuses the consumer of the slice operation pointed to by the target handle
+    into the loop that contains that slice operation.
+ }];
+
+ let arguments =
+ (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$consumer,
+ TransformHandleTypeInterface:$fused_consumer);
+
+ let assemblyFormat = [{
+ $target attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
>From 4fd26b60c142b77f5dedbf6b8aeda64dd48f9590 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 19 Apr 2024 10:09:12 +0000
Subject: [PATCH 2/2] Address algo related comments
---
.../SCF/Transforms/TileUsingInterface.cpp | 265 +++++++++++-------
.../TilingInterface/fuse-consumer.mlir | 145 +++++++++-
2 files changed, 298 insertions(+), 112 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 625586edc1807f..0d57471e587cdd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -1173,20 +1174,8 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
unsigned resultNumber =
cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
- // Check that the consumer results in exactly one value.
- // TODO: Support fusion for consumers yielding more than one result.
- if (consumerOp->getResults().size() != 1) {
- return rewriter.notifyMatchFailure(
- consumerOp,
- "only those consumers returning exactly one result are supported");
- }
Operation *containingOp = candidateSliceOp->getParentOp();
- // Check containing op is "scf::ForOp".
auto forOp = static_cast<scf::ForOp>(containingOp);
- // if (!forOp) {
- // return rewriter.notifyMatchFailure(containingOp,
- // "containing op is not a scf.for");
- // }
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(candidateSliceOp);
@@ -1218,17 +1207,6 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
}
Location loc = forOp.getLoc();
- SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
- // Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
-
SmallVector<Value> newOuts(forOp.getInits());
newOuts.append(dpsInits);
@@ -1244,69 +1222,99 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
loopBody, newLoopBody,
newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
- // Clone the consumer after the insert_slice.
+ // 1 - Clone tensor.insert_slice after original tensor.insert_slice.
rewriter.setInsertionPointAfter(candidateSliceOp);
- SmallVector<Value> newForOpBlockArgsForConsumerDest;
- for (unsigned i = loopBody->getNumArguments(),
- n = newLoopBody->getArguments().size();
- i < n; i++) {
- newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
- }
+ SmallVector<Value> candidateSliceOpOperands =
+ llvm::to_vector(candidateSliceOp->getOperands());
+ tensor::InsertSliceOp clonedCandidateSliceOp =
+ mlir::clone(rewriter, candidateSliceOp,
+ candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+ // 2.a - Clone consumer after the cloned tensor.insert_slice op.
+ rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+ SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+ newLoopBody->getArguments().drop_front(loopBody->getNumArguments()),
+ [](BlockArgument b) -> Value { return b; });
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+ tileableConsumer = clonedConsumerOp;
- // Replace scf.for result's use in the cloned consumer with insert_slice
- // result.
+  // 2.b - Replace the uses of the loop result in the cloned consumer with the
+  // result of the cloned tensor.insert_slice.
rewriter.replaceUsesWithIf(forOp.getResult(resultNumber),
- candidateSliceOp.getResult(),
+ clonedCandidateSliceOp.getResult(),
[&](OpOperand &operand) {
return operand.getOwner() == clonedConsumerOp;
});
- // Generate the tiled implementation of the consumer of the source.
- rewriter.setInsertionPoint(candidateSliceOp);
+ // 3 - Perform tiling of the cloned consumer.
+ rewriter.setInsertionPointAfter(clonedConsumerOp);
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
rewriter,
- cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
+ cast<OffsetSizeAndStrideOpInterface>(
+ clonedCandidateSliceOp.getOperation()),
clonedConsumerOp->getOpOperand(operandNumber));
if (failed(tileAndFuseResult)) {
return rewriter.notifyMatchFailure(tileableConsumer,
"failed to tile consumer op: ");
}
- // Update the source of the candidateSlice to be the cloned consumer.
- SmallVector<Value> candidateSliceOpOperands =
- llvm::to_vector(candidateSliceOp->getOperands());
- candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
- auto bbArgs = newforOp.getBody()->getArguments();
- candidateSliceOpOperands[1] = bbArgs[1 + forOp.getInits().size() + 0];
- tensor::InsertSliceOp clonedCandidateSliceOp =
- mlir::clone(rewriter, candidateSliceOp,
- candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+ // 4 - Extract offset/sizes/strides required to create the tensor.insert_slice
+ // for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides();
+  // Check that all insert strides are 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ clonedCandidateSliceOp, "containingOp's result yield with stride");
+ }
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ // Try to get iter domain position from input position.
+ rewriter.setInsertionPointAfter(clonedConsumerOp);
+ if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ tileableConsumer, "can't get iter domain position from input position");
+ }
- rewriter.replaceAllUsesWith(candidateSliceOp, candidateSliceOp.getSource());
- rewriter.eraseOp(clonedConsumerOp);
+  // Try to get each consumer result's tile position from iter domain position.
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+ llvm::SmallVector<OpFoldResult>>>
+ resultPositions(clonedConsumerOp->getNumResults());
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(tileableConsumer.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultPositions[idx].first, resultPositions[idx].second))) {
+ return rewriter.notifyMatchFailure(
+ tileableConsumer,
+ "can't get result domain position from iter domain position");
+ }
+ }
- // Fix terminator.
+ // 5 - Fix terminator.
scf::YieldOp oldTerminatorOp =
static_cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
- // llvm::outs()<<"\n========= DB - 5 ===========\n"<<funcOp<<"\n";
-
- SmallVector<Value> newYieldOperands;
- for (Value val : oldTerminatorOp.getResults()) {
- if (val == candidateSliceOp.getSource()) {
- newYieldOperands.push_back(candidateSliceOp.getResult());
- } else {
- newYieldOperands.push_back(val);
- }
- }
- newYieldOperands.push_back(clonedCandidateSliceOp.getResult());
+ SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
rewriter.setInsertionPointAfter(oldTerminatorOp);
+ auto bbArgs = newforOp.getBody()->getArguments();
+ for (auto [idx, v] :
+ llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+ rewriter.getIndexAttr(1));
+ newYieldOperands.push_back(rewriter.create<tensor::InsertSliceOp>(
+ clonedCandidateSliceOp->getLoc(), v,
+ bbArgs[1 + forOp.getInits().size() + idx], resultPositions[idx].first,
+ resultPositions[idx].second, strides));
+ }
rewriter.create<scf::YieldOp>(loc, newYieldOperands);
rewriter.eraseOp(oldTerminatorOp);
- // Replace the result of for and consumer op.
+ // 6 - Replace the result of scf.for and consumer op.
for (auto result : llvm::enumerate(forOp.getResults())) {
rewriter.replaceAllUsesWith(result.value(),
newforOp->getResult(result.index()));
@@ -1318,9 +1326,12 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
}
- // Need to erase the old for.
+ rewriter.replaceOp(candidateSliceOp, clonedCandidateSliceOp);
+
+ // 7 - Need to erase the old scf.for.
rewriter.eraseOp(forOp);
rewriter.eraseOp(consumerOp);
+ rewriter.eraseOp(clonedConsumerOp);
return scf::SCFFuseConsumerOfSliceResult{
consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};
@@ -1373,13 +1384,7 @@ tileAndFuseConsumerOfSliceSCFForall(
unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
unsigned resultNumber =
cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
- // Check that the consumer results in exactly one value.
- // TODO: Support fusion for consumers yielding more than one result.
- if (consumerOp->getResults().size() != 1) {
- return rewriter.notifyMatchFailure(
- consumerOp,
- "only those consumers returning exactly one result are supported");
- }
+
OpBuilder::InsertionGuard g(rewriter);
// Using candidateSliceOp->getParentOp() because we have the following case :-
// scf.forall.in_parallel {
@@ -1415,18 +1420,6 @@ tileAndFuseConsumerOfSliceSCFForall(
"consumer op taking the result of scf.forall as init is not supported");
}
- SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
-
- // Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
-
Location loc = forallOp.getLoc();
// Create new scf.forall op.
SmallVector<Value> newOuts(forallOp.getOutputs());
@@ -1444,49 +1437,100 @@ tileAndFuseConsumerOfSliceSCFForall(
loopBody, newLoopBody,
newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
- // Clone the consumer after the parallel_insert_slice.
+ // 1 - Clone tensor.parallel_insert_slice after the original
+ // tensor.parallel_insert_slice.
rewriter.setInsertionPointAfter(candidateSliceOp);
- SmallVector<Value> newForOpBlockArgsForConsumerDest;
- for (unsigned i = loopBody->getNumArguments(),
- n = newLoopBody->getArguments().size();
- i < n; i++) {
- newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
- }
+ SmallVector<Value> candidateSliceOpOperands =
+ llvm::to_vector(candidateSliceOp->getOperands());
+ tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
+ mlir::clone(rewriter, candidateSliceOp,
+ candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+ LLVM_DEBUG(llvm::dbgs() << "Created a clone of the candidate slice op : "
+ << clonedCandidateSliceOp << "\n");
+
+  // 2.a - Clone the consumer after the cloned tensor.parallel_insert_slice.
+ rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+ SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+ newLoopBody->getArguments().drop_front(loopBody->getNumArguments()),
+ [](BlockArgument b) -> Value { return b; });
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+ tileableConsumer = clonedConsumerOp;
- // Replace scf.forall result's use in the consumer with parallel_insert_slice
- // source.
- rewriter.replaceAllUsesWith(forallOp.getResult(resultNumber),
- candidateSliceOp.getSource());
+  // 2.b - Replace the uses of the scf.forall's result in the cloned consumer
+  // with the source of the cloned tensor.parallel_insert_slice.
+ rewriter.replaceUsesWithIf(forallOp.getResult(resultNumber),
+ clonedCandidateSliceOp.getSource(),
+ [&](OpOperand &operand) {
+ return operand.getOwner() == clonedConsumerOp;
+ });
- // Generate the tiled implementation of the consumer of the source.
- rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+ // 3 - Perform tiling of the cloned consumer.
+ rewriter.setInsertionPoint(newforallOp.getTerminator());
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
rewriter,
- cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
+ cast<OffsetSizeAndStrideOpInterface>(
+ clonedCandidateSliceOp.getOperation()),
clonedConsumerOp->getOpOperand(operandNumber));
if (failed(tileAndFuseResult)) {
return rewriter.notifyMatchFailure(tileableConsumer,
"failed to tile consumer op: ");
}
- // Update the source of the candidateSlice to be the cloned consumer.
- rewriter.setInsertionPointAfter(candidateSliceOp);
- SmallVector<Value> candidateSliceOpOperands =
- llvm::to_vector(candidateSliceOp->getOperands());
- candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
- auto bbArgs = newforallOp.getBody()->getArguments();
- candidateSliceOpOperands[1] =
- bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + 0];
- tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
- mlir::clone(rewriter, candidateSliceOp,
- candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
- LLVM_DEBUG(llvm::dbgs() << "Created a clone of the candidate slice op : "
- << clonedCandidateSliceOp << "\n");
+ // 4 - Extract offset/sizes/strides required to create the
+ // tensor.parallel_insert_slice for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides();
+  // Check that all insert strides are 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ clonedCandidateSliceOp, "containingOp's result yield with stride");
+ }
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ // Try to get iter domain position from input position.
+ rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
+ ;
+ if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ tileableConsumer, "can't get iter domain position from input position");
+ }
- rewriter.eraseOp(clonedConsumerOp);
+  // Try to get each consumer result's tile position from iter domain position.
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+ llvm::SmallVector<OpFoldResult>>>
+ resultPositions(clonedConsumerOp->getNumResults());
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(tileableConsumer.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultPositions[idx].first, resultPositions[idx].second))) {
+ return rewriter.notifyMatchFailure(
+ tileableConsumer,
+ "can't get result domain position from iter domain position");
+ }
+ }
+
+ // 5 - Fix terminator.
+ scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+ SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+ newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+ Operation *firstYieldOp = yieldingOps.front();
+ rewriter.setInsertionPoint(firstYieldOp);
+ auto bbArgs = newforallOp.getBody()->getArguments();
+ for (auto [idx, v] :
+ llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+ rewriter.getIndexAttr(1));
+ rewriter.create<tensor::ParallelInsertSliceOp>(
+ firstYieldOp->getLoc(), v,
+ bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+ resultPositions[idx].first, resultPositions[idx].second, strides);
+ }
// Replace the result of scf.forall and consumer op.
for (auto result : llvm::enumerate(forallOp.getResults())) {
@@ -1501,9 +1545,12 @@ tileAndFuseConsumerOfSliceSCFForall(
consumerResult.index()));
}
- // Need to erase the old scf.forall and consumer.
+ // Need to erase the old scf.forall, consumer, cloned consumer and
+ // candidateSliceOp.
rewriter.eraseOp(forallOp);
rewriter.eraseOp(consumerOp);
+ rewriter.eraseOp(clonedConsumerOp);
+ rewriter.eraseOp(candidateSliceOp);
return scf::SCFFuseConsumerOfSliceResult{
consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};
diff --git a/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir
index 9eabb385303515..3d60e32bfa0ccb 100644
--- a/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir
@@ -45,14 +45,14 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[MAT_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
-// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV]]] [32] [1]
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT]] :
// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
-// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#2 :
@@ -111,9 +111,148 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT]] :
// CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0)> // 1-D identity map used for every operand of both linalg.generic ops below
+module {
+ func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ %c4 = arith.constant 4 : index // loop step
+ %c64 = arith.constant 64 : index // loop upper bound
+ %c0 = arith.constant 0 : index // loop lower bound
+ %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { // producer loop; both iter_args start from %arg2
+ %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> // NOTE(review): step 4 with size 32 makes tiles overlap and pass the 64-element bound once %arg3 > 32; presumably fine for a structure-only test - confirm
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { // tiled producer: out + in0 * in1
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32xf32>
+ %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> // the only tensor.insert_slice in the payload, i.e. the op the transform sequence matches
+ scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> // the insert_slice result is yielded in second position, so it leaves the loop as %1#1
+ }
+ %in_operand_2 = tensor.empty() : tensor<64xf32>
+ %out_operand_3 = tensor.empty() : tensor<64xf32>
+ %out_operand_4 = tensor.empty() : tensor<64xf32>
+ %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) { // consumer of %1#1 with two results (the "multi yielding" case)
+ ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.subf %out_0, %13 : f32
+ %15 = arith.addf %out_1, %in : f32
+ linalg.yield %14, %15 : f32, f32
+ } -> (tensor<64xf32>, tensor<64xf32>)
+ return %2#1 : tensor<64xf32> // only the second consumer result is returned
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op // handle to the tensor.insert_slice op(s) found under %arg1
+ %a, %b = transform.test.fuse_consumer %yield
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // runs the consumer-fusion test op on the matched slice; %a and %b are not used afterwards
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+// CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+// CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#3 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)> // 2-D identity map used for every operand of the consumer linalg.generic
+module {
+ func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+ %c4 = arith.constant 4 : index // NOTE(review): not referenced anywhere in this function
+ %c64 = arith.constant 64 : index // NOTE(review): not referenced anywhere in this function
+ %c0 = arith.constant 0 : index // NOTE(review): not referenced anywhere in this function
+ %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { // producer loop; both shared_outs start from %arg2
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> // NOTE(review): induction variables (0..1) are used directly as offsets, so the 32x32 tiles overlap; presumably fine for a structure-only test - confirm
+ %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> // tiled producer
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> // matmul tile goes into the second shared_out, so it leaves the loop as %1#1
+ tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> // slice of %arg6 is written back into the first shared_out
+ }
+ }
+ %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+ %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+ %out_operand_4 = tensor.empty() : tensor<64x64xf32>
+ %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64x64xf32>, tensor<64x64xf32>) { // consumer of %1#1 with two results (the "multi yielding" case)
+ ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.subf %out_0, %13 : f32
+ %15 = arith.addf %out_1, %in : f32
+ linalg.yield %14, %15 : f32, f32
+ } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+ return %2#1 : tensor<64x64xf32> // only the second consumer result is returned
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op // one handle covering the tensor.parallel_insert_slice ops under %arg1
+ %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+ : (!transform.any_op)
+ -> (!transform.any_op, !transform.any_op) // presumably split in payload order, making %first_slice_op the insert of the matmul tile - TODO confirm
+ %a, %b = transform.test.fuse_consumer %first_slice_op
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // only the first slice op is fused; %second_slice_op, %a and %b are not used afterwards
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
+// CHECK: %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#3 :
More information about the Mlir-commits
mailing list