[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Abhishek Varma llvmlistbot at llvm.org
Mon Apr 22 03:33:06 PDT 2024


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/88712

>From 53e607e9467f38dc97656000439ea115d57621e3 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 10 Apr 2024 10:41:46 +0000
Subject: [PATCH] [MLIR][SCF] Add an API to fuse consumer to a producer within
 scf loop

-- This commit adds an API to fuse consumer to a producer within
   scf.for/scf.forall loop.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../SCF/Transforms/TileUsingInterface.h       |  13 +
 .../Dialect/Tensor/Transforms/Transforms.h    |  13 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 455 ++++++++++++++++++
 .../SwapExtractSliceWithProducerPatterns.cpp  |  23 +
 .../tile-and-fuse-consumer.mlir               | 258 ++++++++++
 .../TestTilingInterfaceTransformOps.cpp       |  51 ++
 .../TestTilingInterfaceTransformOps.td        |  19 +
 7 files changed, 831 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be28..c744a17ae9191f 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include <deque>
 
@@ -239,6 +240,18 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                      TilingInterface consumer,
                                      const SCFTileAndFuseOptions &options);
 
+/// Result of `tileAndFuseConsumerOfSlice`.
+struct SCFFuseConsumerOfSliceResult {
+  Operation *origConsumer;          // Original untiled consumer.
+  Operation *tiledAndFusedConsumer; // Tiled and fused consumer op.
+  SmallVector<Operation *> tiledOps;
+};
+/// Fuse the consumer of the loop result yielded via `candidateSliceOp` (a
+/// `tensor.insert_slice` or `tensor.parallel_insert_slice`) by computing the
+/// required consumer slice inside the loop; the original slice op is erased.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
 FailureOr<SmallVector<scf::ForOp>>
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e8a09c4741043b..98447cf62900d5 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
 
@@ -22,7 +23,7 @@ namespace tensor {
 // Patterns
 //===----------------------------------------------------------------------===//
 
-/// Pattern to swap an `tensor.extract_slice` with its producer when the
+/// Method to swap a `tensor.extract_slice` with its producer when the
 /// producer implements the `TilingInterface`. The pattern itself does not
 /// provide a mechanism to control where the application happens. With use of
 /// transform dialect that control is done within the transform dialect. Other
@@ -30,6 +31,16 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
+/// Method to tile the consumer of a `tensor.insert_slice`-like `sliceOp` when
+/// the consumer implements the `TilingInterface`: the tile of the operand
+/// `consumerOp` is given by the offsets and sizes of `sliceOp`. Fails if the
+/// owner of `consumerOp` does not implement `TilingInterface` or if any stride
+/// of `sliceOp` is not the constant 1.
+FailureOr<TilingResult>
+replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
+                                    OffsetSizeAndStrideOpInterface sliceOp,
+                                    OpOperand &consumerOp);
+
 //===----------------------------------------------------------------------===//
 // Populate functions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69df..b97351bd5eeae1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -16,9 +16,11 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -1100,6 +1102,459 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerOfSlice implementation.
+//===----------------------------------------------------------------------===//
+
+/// Checks the assumption required for fusing a consumer: `result` must have
+/// exactly one use and the owner of that use must be an scf.yield. A debug
+/// message is emitted and failure is returned otherwise.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  if (!result.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  // Exactly one use exists at this point, so `use_begin()` can be
+  // dereferenced safely.
+  Operation *soleUser = result.use_begin()->getOwner();
+  if (isa<scf::YieldOp>(soleUser))
+    return success();
+
+  LLVM_DEBUG(llvm::dbgs()
+             << "Expected scf.yield to be the only user, but got -> "
+             << (*soleUser));
+  return failure();
+}
+
+/// Fetch the first untiled consumer of a scf.for's result which is yielded by
+/// a tensor.insert_slice. This function makes the following assumptions :-
+/// 1.  tensor.insert_slice has the scf.for's scf.yield as its only user.
+/// 2.  scf.for's corresponding result has only one use.
+/// Fails if an assumption is violated or if the consumer does not implement
+/// both TilingInterface and DestinationStyleOpInterface.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceResult = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(sliceResult)))
+    return failure();
+  // Step 1. Fetch the corresponding output.
+  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  // Step 2. Check containing op is scf.for and that the scf.yield is the
+  // terminator of that loop (not of an op nested in it); only then does the
+  // yield operand number identify a result of the scf.for.
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp || yieldOpOperand.getOwner()->getParentOp() != containingOp)
+    return failure();
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  // Step 3. Check resulting value of scf.for has exactly one use.
+  if (!llvm::hasSingleElement(resultingValue.getUses()))
+    return failure();
+
+  // Step 4. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp))
+    return failure();
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for. Preconditions that can be
+/// checked up front are verified before any IR is created or modified.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+  auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+
+  OpBuilder::InsertionGuard g(rewriter);
+
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  // 2. Check consumer is not using the fused scf.for result as init.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+  // 3. Check all insert stride is 1. Done before the rewrite starts so that
+  // this failure does not leave partially modified IR behind.
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult s) {
+        return !isConstantIntValue(s, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // 4. Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  // 5. Move the loop body to the new op. The argument count is cached first
+  // because `loopBody` is erased by mergeBlocks.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  unsigned numOldArgs = loopBody->getNumArguments();
+  rewriter.mergeBlocks(loopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(numOldArgs));
+
+  // 6. Clone tensor.insert_slice after original tensor.insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  tensor::InsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+  // 7.a. Clone consumer after the cloned tensor.insert_slice op.
+  rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(numOldArgs),
+      [](BlockArgument b) -> Value { return b; });
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 7.b. Replace all uses of the loop result with the result of the cloned
+  // tensor.insert_slice.
+  rewriter.replaceUsesWithIf(forOp.getResult(resultNumber),
+                             clonedCandidateSliceOp.getResult(),
+                             [&](OpOperand &operand) {
+                               return operand.getOwner() == clonedConsumerOp;
+                             });
+
+  // 8. Perform tiling of the cloned consumer.
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedCandidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                       "failed to tile consumer op: ");
+  }
+  assert(!(tileAndFuseResult->tiledOps.empty()) && "tiled consumer not found");
+
+  // 9. Extract offsets/sizes required to create the tensor.insert_slice
+  // for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  // 10. Try to get iter domain position from input position.
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 11. Try to get all containing op result's position from iter domain
+  // position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(clonedConsumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 12. Fix terminator.
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+  SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  auto bbArgs = newforOp.getBody()->getArguments();
+  for (auto [idx, v] :
+       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        clonedCandidateSliceOp->getLoc(), v,
+        bbArgs[1 + forOp.getInits().size() + idx], resultPositions[idx].first,
+        resultPositions[idx].second, strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+
+  // 13. Replace the result of scf.for and consumer op.
+  for (auto result : llvm::enumerate(forOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
+  }
+
+  rewriter.replaceOp(candidateSliceOp, clonedCandidateSliceOp);
+
+  // 14. Need to erase the old scf.for and the cloned consumer op.
+  rewriter.eraseOp(forOp);
+  rewriter.eraseOp(clonedConsumerOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice. Fails if the slice destination is not a
+/// block argument of a scf.forall, if the result does not have exactly one
+/// use, or if the consumer lacks TilingInterface or DestinationStyleOpInterface.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output; it must be a block argument.
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  if (!iterArg)
+    return failure();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+  if (!forallOp)
+    return failure();
+  unsigned resultNumber = 0;
+  for (BlockArgument val : forallOp.getRegionOutArgs()) {
+    if (val == iterArg) {
+      break;
+    }
+    resultNumber++;
+  }
+  Value resultingValue = forallOp->getResult(resultNumber);
+  // Step 3. Check resulting value of scf.forall has exactly one use.
+  Value::use_range uses = resultingValue.getUses();
+  if (!llvm::hasSingleElement(uses))
+    return failure();
+
+  // Step 4. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp))
+    return failure();
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall. Preconditions that can be
+/// checked up front are verified before any IR is created or modified.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the dest.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // The scf.forall is the grandparent of candidateSliceOp because we have the
+  // following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  auto forallOp = cast<scf::ForallOp>(containingOp);
+
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  // 2. Check consumer is not using scf.forall's output as init.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.forall as init is not supported");
+  }
+  // 3. Check all insert stride is 1. Done before the rewrite starts so that
+  // this failure does not leave partially modified IR behind.
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult s) {
+        return !isConstantIntValue(s, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  Location loc = forallOp.getLoc();
+  // 4. Create new scf.forall op.
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+  // 5. Move the loop body to the new op. The argument count is cached first
+  // because `loopBody` is erased by mergeBlocks.
+  rewriter.eraseOp(newforallOp.getTerminator());
+  Block *loopBody = forallOp.getBody();
+  Block *newLoopBody = newforallOp.getBody();
+  unsigned numOldArgs = loopBody->getNumArguments();
+  rewriter.mergeBlocks(loopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(numOldArgs));
+
+  // 6. Clone tensor.parallel_insert_slice after the original
+  // tensor.parallel_insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+  // 7.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
+  rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(numOldArgs),
+      [](BlockArgument b) -> Value { return b; });
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 7.b. Replace all uses of the scf.forall's result use in the consumer with
+  // the source of the cloned tensor.parallel_insert_slice.
+  rewriter.replaceUsesWithIf(forallOp.getResult(resultNumber),
+                             clonedCandidateSliceOp.getSource(),
+                             [&](OpOperand &operand) {
+                               return operand.getOwner() == clonedConsumerOp;
+                             });
+
+  // 8. Perform tiling of the cloned consumer.
+  rewriter.setInsertionPoint(newforallOp.getTerminator());
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedCandidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                       "failed to tile consumer op: ");
+  }
+  assert(!(tileAndFuseResult->tiledOps.empty()) && "tiled consumer not found");
+
+  // 9. Extract offsets/sizes required to create the
+  // tensor.parallel_insert_slice for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  // 10. Try to get iter domain position from input position.
+  rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 11. Try to get all containing op result's position from iter domain
+  // position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(clonedConsumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 12. Fix terminator.
+  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+      newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+  Operation *firstYieldOp = yieldingOps.front();
+  rewriter.setInsertionPoint(firstYieldOp);
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  for (auto [idx, v] :
+       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOp->getLoc(), v,
+        bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+        resultPositions[idx].first, resultPositions[idx].second, strides);
+  }
+
+  // 13. Replace the result of scf.forall and consumer op.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforallOp->getResult(forallOp.getOutputs().size() +
+                               consumerResult.index()));
+  }
+
+  // 14. Need to erase the old scf.forall, cloned consumer and candidateSliceOp.
+  rewriter.eraseOp(forallOp);
+  rewriter.eraseOp(clonedConsumerOp);
+  rewriter.eraseOp(candidateSliceOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place. Dispatches on the kind of slice op:
+/// tensor.insert_slice (scf.for) or tensor.parallel_insert_slice (scf.forall).
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (auto insertOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp))
+    return tileAndFuseConsumerOfSliceSCFFor(rewriter, insertOp);
+  if (auto parallelInsertOp =
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp))
+    return tileAndFuseConsumerOfSliceSCFForall(rewriter, parallelInsertOp);
+  // Any other kind of operation is not a supported candidate.
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 40d79c20538172..858adfc436164d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -40,3 +40,26 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
 
   return *tiledResult;
 }
+
+FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
+    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
+    OpOperand &consumer) {
+  // The owner of the operand must be tileable through `TilingInterface`.
+  auto tileableOwner = dyn_cast<TilingInterface>(consumer.getOwner());
+  if (!tileableOwner)
+    return failure();
+
+  // `TilingInterface` currently only supports strides being 1.
+  auto isUnitStride = [](OpFoldResult stride) {
+    return isConstantIntValue(stride, 1);
+  };
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isUnitStride))
+    return failure();
+
+  // Tile the owner from the operand tile described by the slice; a failed
+  // tiling is forwarded to the caller as failure.
+  SmallVector<OpFoldResult> tileOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> tileSizes = sliceOp.getMixedSizes();
+  return tileableOwner.getTiledImplementationFromOperandTile(
+      builder, consumer.getOperandNumber(), tileOffsets, tileSizes);
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
new file mode 100644
index 00000000000000..3d60e32bfa0ccb
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -0,0 +1,258 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %yield
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:              ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+         tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+        : (!transform.any_op)
+        -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %first_slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
+//      CHECK:      %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:              ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %out_operand_4 = tensor.empty() : tensor<64xf32>
+    %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.subf %out_0, %13 : f32
+          %15 = arith.addf %out_1, %in : f32
+          linalg.yield %14, %15 : f32, f32
+    } -> (tensor<64xf32>, tensor<64xf32>)
+    return %2#1 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %yield
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:              ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+//      CHECK:      %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+//      CHECK:      %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#3 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+         tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_4 = tensor.empty() : tensor<64x64xf32>
+    %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64x64xf32>, tensor<64x64xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.subf %out_0, %13 : f32
+          %15 = arith.addf %out_1, %in : f32
+          linalg.yield %14, %15 : f32, f32
+    } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+    return %2#1 : tensor<64x64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+        : (!transform.any_op)
+        -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %first_slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
+//      CHECK:      %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:              ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#3 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 335db1a61f476e..181d8ebc68f9e7 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -160,6 +160,57 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                         : DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerOp
+//===----------------------------------------------------------------------===//
+
+/// Runs `scf::tileAndFuseConsumerOfSlice` on every payload op in `payloadOps`;
+/// maps result 0 to the original consumers and result 1 to the fused ones.
+template <typename Range>
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+                  Range &&payloadOps, TransformResults &transformResults) {
+  SmallVector<Operation *> originalConsumerOps;
+  SmallVector<Operation *> fusedConsumerOps;
+
+  for (Operation *target : payloadOps) {
+    rewriter.setInsertionPoint(target);
+
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+        scf::tileAndFuseConsumerOfSlice(rewriter, target);
+
+    if (failed(fuseConsumerResults))
+      return failure(); // Stop at the first slice that cannot be fused.
+
+    // Collect the ops to report back as handles once all targets are fused.
+    originalConsumerOps.push_back(fuseConsumerResults->origConsumer);
+    fusedConsumerOps.push_back(fuseConsumerResults->tiledAndFusedConsumer);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
+  transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  LogicalResult result = // Mapped to a definite failure below on error.
+      applyFuseConsumer(rewriter, getOperation(),
+                        state.getPayloadOps(getTarget()), transformResults);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseConsumerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects); // `$target` is consumed by this op.
+  producesHandle(getConsumer(), effects); // Result 0: original consumer.
+  producesHandle(getFusedConsumer(), effects); // Result 1: fused consumer.
+  modifiesPayload(effects); // `apply` rewrites payload IR via the rewriter.
+}
+
 //===----------------------------------------------------------------------===//
 // TestTileUsingForallOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index ef42375e5286d8..d55d746bd6aa90 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
   }];
 }
 
+def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the consumer of the operation pointed to by the target handle.
+    Returns handles to the original and to the tiled-and-fused consumer.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$consumer,
+                      TransformHandleTypeInterface:$fused_consumer);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,



More information about the Mlir-commits mailing list