[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Abhishek Varma llvmlistbot at llvm.org
Wed May 15 03:48:43 PDT 2024


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/88712

>From 3b493574f946bdd2a792d8e36753a1566752d93d Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Sat, 16 Mar 2024 09:12:12 +0800
Subject: [PATCH 1/5] [mlir][linalg] Enable fuse consumer

This patch adds support for consumer fusion to the tiling interface.

- Add interface method 'getIterationDomainTileFromOperandTile' to the tiling
  interface, which gets the iteration domain position from the operand position.
- Add interface method 'getTiledImplementationFromOperandTile' to the tiling
  interface, which generates the tiled implementation according to the operand
  position.
---
 .../mlir/Interfaces/TilingInterface.td        |  67 ++++++++++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |   4 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 106 +++++++++++++-----
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   |  12 +-
 4 files changed, 149 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c2424..84f7dec2f4003 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           by the tiled implementation. Expects the same `offsets` and `sizes` as
           used to obtain the tiled implementation of the operation.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"getResultTilePosition",
         /*args=*/(ins
           "OpBuilder &":$b,
           "unsigned":$resultNumber,
           "ArrayRef<OpFoldResult> ":$offsets,
           "ArrayRef<OpFoldResult> ":$sizes,
-          "SmallVector<OpFoldResult> &":$resultOffsets,
-          "SmallVector<OpFoldResult> &":$resultSizes),
+          "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of iteration domain tile computed by the
+          tiled operation.
+        }],
+        /*retType=*/"::mlir::LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             iteration space).
           - `sizes` provides the size of the tile.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"generateResultTileValue",
         /*args=*/(ins
           "OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand tile position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand tile.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the tile of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformations are done, this method can be used to lower to scalar
           code that can then be lowered to LLVM or SPIR-V dialects.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"generateScalarImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..03716eaaa6358 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2454,8 +2454,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
 
 LogicalResult SoftmaxOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
-    SmallVector<OpFoldResult> &resultSizes) {
+    ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
+    SmallVectorImpl<OpFoldResult> &resultSizes) {
   if (resultNumber == 0) {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5..71e9c3771dcde 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
         }));
   }
 
-  // Instantiate the tiled implementation of the operation.
+  /// Instantiate the tiled implementation of the operation.
   FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
-  // Return the details of the output tile generated by the tiled
-  // implementation.
+  void
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
+                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+    unsigned numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[index] = value.offset;
+        mappedSizes[index] = value.size;
+      }
+    }
+    for (const auto &&[index, value] :
+         llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
+      mappedOffsets[dimPosition] = offsets[index];
+      mappedSizes[dimPosition] = sizes[index];
+    }
+  }
+
+  /// Return the offsets and sizes of the iteration domain tile that
+  /// corresponds to the given tile of operand `operandNumber`.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return emitError(op->getLoc(),
+                       "unhandled get iter domain position when operand is not "
+                       "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation.
   LogicalResult
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
 
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           mappedOffsets, mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d25efcf50ec56..296c5fc7a5c2b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
     return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     // The iteration domain is over outer dimensions of packed layout. In this
     // context, the outer dimensions of `resultOffsets` are `offsets`. The
     // inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets = llvm::to_vector(offsets);
     resultSizes = llvm::to_vector(sizes);
     return success();

>From ec05f4cb6a58d93fdf26c5fa5b208978c83fc1bf Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 10 Apr 2024 10:41:46 +0000
Subject: [PATCH 2/5] [MLIR][SCF] Add an API to fuse consumer to a producer
 within scf loop

-- This commit adds an API to fuse a consumer into a producer within an
   scf.for/scf.forall loop.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../SCF/Transforms/TileUsingInterface.h       |  13 +
 .../Dialect/Tensor/Transforms/Transforms.h    |  13 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 429 ++++++++++++++++++
 .../SwapExtractSliceWithProducerPatterns.cpp  |  23 +
 .../tile-and-fuse-consumer.mlir               | 258 +++++++++++
 .../TestTilingInterfaceTransformOps.cpp       |  51 +++
 .../TestTilingInterfaceTransformOps.td        |  19 +
 7 files changed, 805 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..c744a17ae9191 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include <deque>
 
@@ -239,6 +240,18 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                      TilingInterface consumer,
                                      const SCFTileAndFuseOptions &options);
 
+/// Fuse the consumer of the source of `candidateSliceOp` by computing the
+/// required slice of the consumer in-place.  Note that the method
+/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
+/// value but does not delete the slice operation.
+struct SCFFuseConsumerOfSliceResult {
+  Operation *origConsumer;          // Original untiled consumer.
+  Operation *tiledAndFusedConsumer; // Tiled and fused consumer op.
+  SmallVector<Operation *> tiledOps;
+};
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
 FailureOr<SmallVector<scf::ForOp>>
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e8a09c4741043..98447cf62900d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
 
@@ -22,7 +23,7 @@ namespace tensor {
 // Patterns
 //===----------------------------------------------------------------------===//
 
-/// Pattern to swap an `tensor.extract_slice` with its producer when the
+/// Method to swap an `tensor.extract_slice` with its producer when the
 /// producer implements the `TilingInterface`. The pattern itself does not
 /// provide a mechanism to control where the application happens. With use of
 /// transform dialect that control is done within the transform dialect. Other
@@ -30,6 +31,16 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
+/// Method to swap a `tensor.insert_slice` with its consumer when the
+/// consumer implements the `TilingInterface`. The method itself does not
+/// provide a mechanism to control where the application happens. With use of
+/// transform dialect that control is done within the transform dialect. Other
+/// use cases can wrap this method and add necessary controls.
+FailureOr<TilingResult>
+replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
+                                    OffsetSizeAndStrideOpInterface sliceOp,
+                                    OpOperand &consumerOp);
+
 //===----------------------------------------------------------------------===//
 // Populate functions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69d..d4147e2b0602f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -16,9 +16,11 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -1100,6 +1102,433 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// A utility function that checks whether the passed value has only one user.
+/// In case the defining operation is a tensor.insert_slice, it checks if the
+/// user is scf.yield.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  Value::use_range uses = result.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = (*uses.begin());
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp));
+    return failure();
+  }
+  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Fetch the first untiled consumer of a scf.for's result which is yielded by
+/// a tensor.insert_slice. This function makes the following assumptions :-
+/// 1.  tensor.insert_slice has scf.yield as its only user.
+/// 2.  scf.for's corresponding result has only one use.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceResult = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp.getResult()))) {
+    return nullptr;
+  }
+  // Step 1. Fetch the corresponding output.
+  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  // Step 2. Check containing op is scf.for.
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp) {
+    return nullptr;
+  }
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  // Step 3. Check resulting value of scf.for has exactly one use.
+  if (!llvm::hasSingleElement(resultingValue.getUses())) {
+    return nullptr;
+  }
+
+  // Step 4. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return nullptr;
+  }
+  if (containingOp->getBlock() != consumerOp->getBlock()) {
+    return nullptr;
+  }
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice.
+  OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
+  if (!consumerOpOperand) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = consumerOpOperand->getOwner();
+  unsigned operandNumber = consumerOpOperand->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+
+  auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+
+  // 2. Check consumer is not using scf.for's output as init.
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // 3. Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  // 4. Move the loop body to the new op.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // 5.a. Clone consumer after the cloned tensor.insert_slice op.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 5.b. Replace all uses of the loop result with the result of the cloned
+  // tensor.insert_slice.
+  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+    operandToReplace.set(candidateSliceOp.getResult());
+  });
+
+  // 6 - Perform tiling of the cloned consumer.
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                       "failed to tile consumer op: ");
+  }
+
+  // 7 - Extract offset/sizes/strides required to create the tensor.insert_slice
+  // for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+  // 8. Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+  // 9. Try to get iter domain position from input position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 10. Try to fetch the offset and size for all results of the cloned
+  // consumer. This would then be used to form the corresponding
+  // tensor.insert_slice later.
+  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 11. Fix terminator.
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+  SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  MutableArrayRef<BlockArgument> bbArgs = newforOp.getBody()->getArguments();
+  for (auto [idx, v] :
+       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        candidateSliceOp->getLoc(), v,
+        bbArgs[1 + forOp.getInits().size() + idx], resultOffsets[idx],
+        resultSizes[idx], strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+
+  // 12. Replace the result of scf.for and consumer op.
+  for (auto &&[oldResult, newResult] :
+       llvm::zip_first(forOp.getResults(), newforOp.getResults())) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
+  }
+
+  for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        oldValue, newforOp->getResult(forOp.getInits().size() + index));
+  }
+
+  // 13. Need to erase the old scf.for and the cloned consumer op.
+  rewriter.eraseOp(forOp);
+  rewriter.eraseOp(clonedConsumerOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice. Returns nullptr when no such consumer
+/// exists, so callers can report a match failure instead of asserting.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output. Use dyn_cast rather than cast so
+  // that a destination which is not a block argument is rejected gracefully.
+  auto iterArg = dyn_cast<BlockArgument>(candidateSliceOp.getDest());
+  if (!iterArg)
+    return nullptr;
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp)
+    return nullptr;
+  Value resultingValue =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  // Step 3. Check resulting value of scf.forall has exactly one use.
+  Value::use_range uses = resultingValue.getUses();
+  if (!llvm::hasSingleElement(uses))
+    return nullptr;
+
+  // Step 4. Fetch that single use and the op owning it.
+  OpOperand &operand = *uses.begin();
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp))
+    return nullptr;
+  // The consumer must live in the same block as the scf.forall.
+  if (containingOp->getBlock() != consumerOp->getBlock())
+    return nullptr;
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the dest.
+  OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
+  if (!consumerOpOperand) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = consumerOpOperand->getOwner();
+  unsigned operandNumber = consumerOpOperand->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  auto forallOp = cast<scf::ForallOp>(containingOp);
+
+  // 2. Check consumer is not using scf.forall's output as init.
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.forall as init is not supported");
+  }
+
+  // 3. Check all insert strides are 1 up front, before any IR is modified.
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+  Location loc = forallOp.getLoc();
+  // 4. Create new scf.forall op.
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+  // 5. Move the loop body to the new op.
+  rewriter.eraseOp(newforallOp.getTerminator());
+  Block *loopBody = forallOp.getBody();
+  Block *newLoopBody = newforallOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // 6.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 6.b. Replace all uses of the scf.forall's result use in the consumer with
+  // the source of the cloned tensor.parallel_insert_slice.
+  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+    operandToReplace.set(candidateSliceOp.getSource());
+  });
+
+  // 7. Perform tiling of the cloned consumer.
+  rewriter.setInsertionPoint(newforallOp.getTerminator());
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                       "failed to tile consumer op");
+  }
+
+  // 8. Extract offset/sizes required to create the
+  // tensor.parallel_insert_slice for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+  // 9. Try to get iter domain position from input position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 10. Try to fetch the offset and size for all results of the cloned
+  // consumer. This would then be used to form the corresponding
+  // tensor.parallel_insert_slice later.
+  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 11. Fix terminator.
+  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+  Location firstYieldOpLoc =
+      (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+  MutableArrayRef<BlockArgument> bbArgs = newforallOp.getBody()->getArguments();
+  for (auto [idx, v] :
+       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOpLoc, v,
+        bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+        resultOffsets[idx], resultSizes[idx], strides);
+  }
+
+  // 12. Replace the result of scf.forall and consumer op.
+  for (auto &&[oldResult, newResult] :
+       llvm::zip_first(forallOp.getResults(), newforallOp.getResults())) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
+  }
+
+  for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        oldValue, newforallOp->getResult(forallOp.getOutputs().size() + index));
+  }
+
+  // 13. Need to erase the old scf.forall and cloned consumer.
+  rewriter.eraseOp(forallOp);
+  rewriter.eraseOp(clonedConsumerOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
+
+/// Fuse the consumer of `candidateSliceOp` into its enclosing loop in-place:
+/// tensor.insert_slice is handled by the scf.for implementation and
+/// tensor.parallel_insert_slice by the scf.forall one.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (auto insertOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp))
+    return tileAndFuseConsumerOfSliceSCFFor(rewriter, insertOp);
+  if (auto parallelInsertOp =
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp))
+    return tileAndFuseConsumerOfSliceSCFForall(rewriter, parallelInsertOp);
+  // Any other kind of op is not a supported candidate slice.
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 40d79c2053817..858adfc436164 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -40,3 +40,26 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
 
   return *tiledResult;
 }
+
+FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
+    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
+    OpOperand &consumer) {
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumer.getOwner());
+  if (!tileableConsumer)
+    return failure();
+
+  // Reject non-unit strides: `TilingInterface` only supports strides of 1.
+  auto isUnitStride = [](OpFoldResult stride) {
+    return isConstantIntValue(stride, 1);
+  };
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isUnitStride))
+    return failure();
+  // Tile the consumer w.r.t. the operand tile described by the slice.
+  FailureOr<TilingResult> tiled =
+      tileableConsumer.getTiledImplementationFromOperandTile(
+          builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
+          sliceOp.getMixedSizes());
+  if (failed(tiled))
+    return failure();
+  return *tiled;
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
new file mode 100644
index 0000000000000..3d60e32bfa0cc
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -0,0 +1,258 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %yield
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:              ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+         tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+        : (!transform.any_op)
+        -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %first_slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
+//      CHECK:      %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:              ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %out_operand_4 = tensor.empty() : tensor<64xf32>
+    %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.subf %out_0, %13 : f32
+          %15 = arith.addf %out_1, %in : f32
+          linalg.yield %14, %15 : f32, f32
+    } -> (tensor<64xf32>, tensor<64xf32>)
+    return %2#1 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %yield
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[INSERT_MAT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:              ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+//      CHECK:      %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+//      CHECK:      %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#3 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+         tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_4 = tensor.empty() : tensor<64x64xf32>
+    %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64x64xf32>, tensor<64x64xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.subf %out_0, %13 : f32
+          %15 = arith.addf %out_1, %in : f32
+          linalg.yield %14, %15 : f32, f32
+    } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+    return %2#1 : tensor<64x64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+        : (!transform.any_op)
+        -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %first_slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
+//      CHECK:      %[[SLICE_OPERAND1:.*]] = tensor.extract_slice %[[MAT_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:              ins(%[[SLICE_OPERAND1]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#3 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 335db1a61f476..181d8ebc68f9e 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -160,6 +160,57 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                         : DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerOp
+//===----------------------------------------------------------------------===//
+
+/// Run consumer fusion on every payload op in `payloadOps`. On success, result
+/// 0 of `transformOp` is set to the original consumers and result 1 to the
+/// tiled-and-fused consumers; the first failing fusion returns failure.
+template <typename Range>
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+                  Range &&payloadOps, TransformResults &transformResults) {
+  SmallVector<Operation *> origConsumers, fusedConsumers;
+
+  for (Operation *sliceOp : payloadOps) {
+    rewriter.setInsertionPoint(sliceOp);
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
+        scf::tileAndFuseConsumerOfSlice(rewriter, sliceOp);
+    // Stop at the first payload op whose consumer cannot be fused.
+    if (failed(fused))
+      return failure();
+
+    // Remember both ops so they can be handed back to the transform op.
+    origConsumers.push_back(fused->origConsumer);
+    fusedConsumers.push_back(fused->tiledAndFusedConsumer);
+  }
+
+  // Publish the collected handles on the two results of the transform op.
+  transformResults.set(transformOp->getOpResult(0), origConsumers);
+  transformResults.set(transformOp->getOpResult(1), fusedConsumers);
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  if (failed(applyFuseConsumer(rewriter, getOperation(),
+                               state.getPayloadOps(getTarget()),
+                               transformResults)))
+    return DiagnosedSilenceableFailure::definiteFailure();
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseConsumerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getConsumer(), effects);
+  producesHandle(getFusedConsumer(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // TestTileUsingForallOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index ef42375e5286d..d55d746bd6aa9 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
   }];
 }
 
+def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the consumer of the slice operation pointed to by the target handle
+    into its enclosing loop. Yields the original and the fused consumer ops.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$consumer,
+                      TransformHandleTypeInterface:$fused_consumer);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

>From 4af443d4fc9ab995eb17ff8ed1ec2dd5f6742ae7 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 2 May 2024 05:29:36 +0000
Subject: [PATCH 3/5] WIP unify and move common code out

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 761 ++++++++++--------
 1 file changed, 414 insertions(+), 347 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index d4147e2b0602f..eac23bfec4fa4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1173,11 +1173,12 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
   return &operand;
 }
 
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf.for.
-static FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
-                                 tensor::InsertSliceOp candidateSliceOp) {
+typedef struct helper1struct {
+  OpOperand *consumerOpOperand;
+};
+
+template <typename T>
+static helper1struct helper1(T candidateSliceOp) {
   // 1. Get the consumer of scf.for for the result yielded by
   // tensor.insert_slice.
   OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
@@ -1190,377 +1191,443 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
   unsigned resultNumber =
       cast<OpResult>(consumerOpOperand->get()).getResultNumber();
 
-  auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+  if (std::is_same<T, tensor::InsertSliceOp>::value) {
+    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
 
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(candidateSliceOp);
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(candidateSliceOp);
 
-  // 2. Check consumer is not using scf.for's output as init.
-  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
-  SmallVector<Value> dpsInits =
-      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
-    return rewriter.notifyMatchFailure(
-        consumerOp,
-        "consumer op taking the result of scf.for as init is not supported");
-  }
+    // 2. Check consumer is not using scf.for's output as init.
+    auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+    SmallVector<Value> dpsInits =
+        llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+    if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
+      return rewriter.notifyMatchFailure(
+          consumerOp,
+          "consumer op taking the result of scf.for as init is not supported");
+    }
 
-  Location loc = forOp.getLoc();
-  SmallVector<Value> newOuts(forOp.getInits());
-  newOuts.append(dpsInits);
-
-  // 3. Create new scf.for op.
-  rewriter.setInsertionPoint(consumerOp);
-  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
-                                              forOp.getUpperBound(),
-                                              forOp.getStep(), newOuts);
-  // 4. Move the loop body to the new op.
-  Block *loopBody = forOp.getBody();
-  Block *newLoopBody = newforOp.getBody();
-  rewriter.mergeBlocks(
-      loopBody, newLoopBody,
-      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+    Location loc = forOp.getLoc();
+    SmallVector<Value> newOuts(forOp.getInits());
+    newOuts.append(dpsInits);
+
+    // 3. Create new scf.for op.
+    rewriter.setInsertionPoint(consumerOp);
+    auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                                forOp.getUpperBound(),
+                                                forOp.getStep(), newOuts);
+    // 4. Move the loop body to the new op.
+    Block *loopBody = forOp.getBody();
+    Block *newLoopBody = newforOp.getBody();
+    rewriter.mergeBlocks(
+        loopBody, newLoopBody,
+        newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+    helper1struct ret;
+    ret.consumerOpOperand = consumerOpOperand;
+    ret.newLoop = newforOp;
+    return ret;
+  }
+
+  /// Implementation of fusing consumer of a single slice by computing the
+  /// slice of the consumer in-place for scf.for.
+  static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+  tileAndFuseConsumerOfSliceSCFFor(RewriterBase & rewriter,
+                                   tensor::InsertSliceOp candidateSliceOp) {
+    // 1. Get the consumer of scf.for for the result yielded by
+    // tensor.insert_slice.
+    OpOperand *consumerOpOperand =
+        getUntiledConsumerFromSlice(candidateSliceOp);
+    if (!consumerOpOperand) {
+      return rewriter.notifyMatchFailure(candidateSliceOp,
+                                         "could not fetch consumer to fuse");
+    }
+    Operation *consumerOp = consumerOpOperand->getOwner();
+    unsigned operandNumber = consumerOpOperand->getOperandNumber();
+    unsigned resultNumber =
+        cast<OpResult>(consumerOpOperand->get()).getResultNumber();
 
-  // 5.a. Clone consumer after the cloned tensor.insert_slice op.
-  rewriter.setInsertionPointAfter(candidateSliceOp);
-  auto newForOpBlockArgsForConsumerDest =
-      newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
-  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
 
-  // 5.b. Replace all uses of the loop result with the result of the cloned
-  // tensor.insert_slice.
-  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
-  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-    operandToReplace.set(candidateSliceOp.getResult());
-  });
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(candidateSliceOp);
 
-  // 6 - Perform tiling of the cloned consumer.
-  rewriter.setInsertionPointAfter(clonedConsumerOp);
-  FailureOr<TilingResult> tileAndFuseResult =
-      tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter,
-          cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
-          clonedConsumerOp->getOpOperand(operandNumber));
-  if (failed(tileAndFuseResult)) {
-    return rewriter.notifyMatchFailure(clonedConsumerOp,
-                                       "failed to tile consumer op: ");
-  }
+    // 2. Check consumer is not using scf.for's output as init.
+    auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+    SmallVector<Value> dpsInits =
+        llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+    if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
+      return rewriter.notifyMatchFailure(
+          consumerOp,
+          "consumer op taking the result of scf.for as init is not supported");
+    }
 
-  // 7 - Extract offset/sizes/strides required to create the tensor.insert_slice
-  // for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
-  // 8. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
-  // 9. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  rewriter.setInsertionPointAfter(clonedConsumerOp);
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
+    Location loc = forOp.getLoc();
+    SmallVector<Value> newOuts(forOp.getInits());
+    newOuts.append(dpsInits);
+
+    // 3. Create new scf.for op.
+    rewriter.setInsertionPoint(consumerOp);
+    auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                                forOp.getUpperBound(),
+                                                forOp.getStep(), newOuts);
+    // 4. Move the loop body to the new op.
+    Block *loopBody = forOp.getBody();
+    Block *newLoopBody = newforOp.getBody();
+    rewriter.mergeBlocks(
+        loopBody, newLoopBody,
+        newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+    // 5.a. Clone consumer after the cloned tensor.insert_slice op.
+    rewriter.setInsertionPointAfter(candidateSliceOp);
+    auto newForOpBlockArgsForConsumerDest =
+        newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
+    auto clonedConsumerOp =
+        cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+            rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+    // 5.b. Replace all uses of the loop result with the result of the cloned
+    // tensor.insert_slice.
+    OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+    rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+      operandToReplace.set(candidateSliceOp.getResult());
+    });
+
+    // 6 - Perform tiling of the cloned consumer.
+    rewriter.setInsertionPointAfter(clonedConsumerOp);
+    FailureOr<TilingResult> tileAndFuseResult =
+        tensor::replaceInsertSliceWithTiledConsumer(
+            rewriter,
+            cast<OffsetSizeAndStrideOpInterface>(
+                candidateSliceOp.getOperation()),
+            clonedConsumerOp->getOpOperand(operandNumber));
+    if (failed(tileAndFuseResult)) {
+      return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                         "failed to tile consumer op: ");
+    }
 
-  // 10. Try to fetch the offset and size for all results of the cloned
-  // consumer. This would then be used to form the corresponding
-  // tensor.insert_slice later.
-  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-    if (failed(clonedConsumerOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
+    // 7 - Extract offset/sizes/strides required to create the
+    // tensor.insert_slice for each result of the consumer.
+    SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+    // 8. Check that all insert strides are 1.
+    if (llvm::any_of(strides, [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
+      return rewriter.notifyMatchFailure(
+          candidateSliceOp, "containingOp's result yield with stride");
+    }
+    // 9. Try to get iter domain position from input position.
+    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+    rewriter.setInsertionPointAfter(clonedConsumerOp);
+    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+            iterDomainSizes))) {
       return rewriter.notifyMatchFailure(
           clonedConsumerOp,
-          "can't get result domain position from iter domain position");
+          "can't get iter domain position from input position");
     }
-  }
-
-  // 11. Fix terminator.
-  scf::YieldOp oldTerminatorOp =
-      cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
-  SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
-  rewriter.setInsertionPointAfter(oldTerminatorOp);
-  MutableArrayRef<BlockArgument> bbArgs = newforOp.getBody()->getArguments();
-  for (auto [idx, v] :
-       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
-    SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
-                                      rewriter.getIndexAttr(1));
-    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        candidateSliceOp->getLoc(), v,
-        bbArgs[1 + forOp.getInits().size() + idx], resultOffsets[idx],
-        resultSizes[idx], strides);
-    newYieldOperands.push_back(newInsertSliceOp);
-  }
-  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
-  rewriter.eraseOp(oldTerminatorOp);
-
-  // 12. Replace the result of scf.for and consumer op.
-  for (auto &&[oldResult, newResult] :
-       llvm::zip_first(forOp.getResults(), newforOp.getResults())) {
-    rewriter.replaceAllUsesWith(oldResult, newResult);
-  }
-
-  for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
-    rewriter.replaceAllUsesWith(
-        oldValue, newforOp->getResult(forOp.getInits().size() + index));
-  }
-
-  // 13. Need to erase the old scf.for and the cloned consumer op.
-  rewriter.eraseOp(forOp);
-  rewriter.eraseOp(clonedConsumerOp);
 
-  return scf::SCFFuseConsumerOfSliceResult{
-      consumerOp, tileAndFuseResult->tiledOps[0], {}};
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static OpOperand *
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
-  Value sliceDest = candidateSliceOp.getDest();
-  auto iterArg = cast<BlockArgument>(sliceDest);
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp) {
-    return nullptr;
-  }
-  Value resultingValue =
-      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-  // Step 3. Check resulting value of scf.forall has exactly one use.
-  Value::use_range uses = resultingValue.getUses();
-  if (!llvm::hasSingleElement(uses)) {
-    return nullptr;
-  }
-
-  // Step 4. Get uses.
-  OpOperand &operand = (*resultingValue.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.forall, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp)) {
-    return nullptr;
-  }
-  if (containingOp->getBlock() != consumerOp->getBlock()) {
-    return nullptr;
-  }
-  return &operand;
-}
-
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf.forall.
-static FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSliceSCFForall(
-    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // 1. Get the consumer of the dest.
-  OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
-  if (!consumerOpOperand) {
-    return rewriter.notifyMatchFailure(candidateSliceOp,
-                                       "could not fetch consumer to fuse");
-  }
-  Operation *consumerOp = consumerOpOperand->getOwner();
-  unsigned operandNumber = consumerOpOperand->getOperandNumber();
-  unsigned resultNumber =
-      cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+    // 10. Try to fetch the offset and size for all results of the cloned
+    // consumer. This would then be used to form the corresponding
+    // tensor.insert_slice later.
+    unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+        totalNumResultsOfConsumer);
+    SmallVector<SmallVector<OpFoldResult>> resultSizes(
+        totalNumResultsOfConsumer);
+    for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+      if (failed(clonedConsumerOp.getResultTilePosition(
+              rewriter, idx, iterDomainOffsets, iterDomainSizes,
+              resultOffsets[idx], resultSizes[idx]))) {
+        return rewriter.notifyMatchFailure(
+            clonedConsumerOp,
+            "can't get result domain position from iter domain position");
+      }
+    }
 
-  OpBuilder::InsertionGuard g(rewriter);
-  // Using candidateSliceOp->getParentOp() because we have the following case :-
-  // scf.forall.in_parallel {
-  //   tensor.parallel_insert_slice ...
-  // }
-  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+    // 11. Fix terminator.
+    scf::YieldOp oldTerminatorOp =
+        cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+    SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+    rewriter.setInsertionPointAfter(oldTerminatorOp);
+    MutableArrayRef<BlockArgument> bbArgs = newforOp.getBody()->getArguments();
+    for (auto [idx, v] :
+         llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+      SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+                                        rewriter.getIndexAttr(1));
+      Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+          candidateSliceOp->getLoc(), v,
+          bbArgs[1 + forOp.getInits().size() + idx], resultOffsets[idx],
+          resultSizes[idx], strides);
+      newYieldOperands.push_back(newInsertSliceOp);
+    }
+    rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+    rewriter.eraseOp(oldTerminatorOp);
 
-  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
-  auto forallOp = cast<scf::ForallOp>(containingOp);
+    // 12. Replace the result of scf.for and consumer op.
+    for (auto &&[oldResult, newResult] :
+         llvm::zip_first(forOp.getResults(), newforOp.getResults())) {
+      rewriter.replaceAllUsesWith(oldResult, newResult);
+    }
 
-  // 2. Check consumer is not using scf.forall's output as init.
-  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
-  SmallVector<Value> dpsInits =
-      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
-    return rewriter.notifyMatchFailure(
-        consumerOp,
-        "consumer op taking the result of scf.forall as init is not supported");
-  }
+    for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+      rewriter.replaceAllUsesWith(
+          oldValue, newforOp->getResult(forOp.getInits().size() + index));
+    }
 
-  Location loc = forallOp.getLoc();
-  // 3. Create new scf.forall op.
-  SmallVector<Value> newOuts(forallOp.getOutputs());
-  newOuts.append(dpsInits);
-  rewriter.setInsertionPoint(consumerOp);
-  auto newforallOp = rewriter.create<scf::ForallOp>(
-      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
-      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-
-  // 4. Move the loop body to the new op.
-  rewriter.eraseOp(newforallOp.getTerminator());
-  Block *loopBody = forallOp.getBody();
-  Block *newLoopBody = newforallOp.getBody();
-  rewriter.mergeBlocks(
-      loopBody, newLoopBody,
-      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+    // 13. Need to erase the old scf.for and the cloned consumer op.
+    rewriter.eraseOp(forOp);
+    rewriter.eraseOp(clonedConsumerOp);
+
+    return scf::SCFFuseConsumerOfSliceResult{
+        consumerOp, tileAndFuseResult->tiledOps[0], {}};
+  }
+
+  /// Fetch the first untiled consumer of a scf.forall's result which is yielded
+  /// by a tensor.parallel_insert_slice.
+  static OpOperand *getUntiledConsumerFromSlice(
+      tensor::ParallelInsertSliceOp candidateSliceOp) {
+    // Step 1. Fetch the corresponding output
+    Value sliceDest = candidateSliceOp.getDest();
+    auto iterArg = cast<BlockArgument>(sliceDest);
+    Operation *containingOp = iterArg.getOwner()->getParentOp();
+    // Step 2. Check that the containing op is scf.forall.
+    auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+    if (!forallOp) {
+      return nullptr;
+    }
+    Value resultingValue =
+        forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+    // Step 3. Check resulting value of scf.forall has exactly one use.
+    Value::use_range uses = resultingValue.getUses();
+    if (!llvm::hasSingleElement(uses)) {
+      return nullptr;
+    }
 
-  // 5.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
-  rewriter.setInsertionPointAfter(candidateSliceOp);
-  auto newForOpBlockArgsForConsumerDest =
-      newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
-  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
-
-  // 5.b. Replace all uses of the scf.forall's result use in the consumer with
-  // the source of the cloned tensor.parallel_insert_slice.
-  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
-  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-    operandToReplace.set(candidateSliceOp.getSource());
-  });
-
-  // 6. Perform tiling of the cloned consumer.
-  rewriter.setInsertionPoint(newforallOp.getTerminator());
-  FailureOr<TilingResult> tileAndFuseResult =
-      tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter,
-          cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
-          clonedConsumerOp->getOpOperand(operandNumber));
-  if (failed(tileAndFuseResult)) {
-    return rewriter.notifyMatchFailure(clonedConsumerOp,
-                                       "failed to tile consumer op: ");
-  }
+    // Step 4. Get uses.
+    OpOperand &operand = (*resultingValue.getUses().begin());
+    Operation *consumerOp = operand.getOwner();
+    // TODO: The consumer's result must be initialized before scf.forall; for
+    //       now DestinationStyleOpInterface gives the result shape from init.
+    //       Add support for other ops, e.g. ones with InferTypeOpInterface.
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp)) {
+      return nullptr;
+    }
+    if (containingOp->getBlock() != consumerOp->getBlock()) {
+      return nullptr;
+    }
+    return &operand;
+  }
+
+  /// Implementation of fusing consumer of a single slice by computing the
+  /// slice of the consumer in-place for scf.forall.
+  static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+  tileAndFuseConsumerOfSliceSCFForall(
+      RewriterBase & rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+    // 1. Get the consumer of the dest.
+    OpOperand *consumerOpOperand =
+        getUntiledConsumerFromSlice(candidateSliceOp);
+    if (!consumerOpOperand) {
+      return rewriter.notifyMatchFailure(candidateSliceOp,
+                                         "could not fetch consumer to fuse");
+    }
+    Operation *consumerOp = consumerOpOperand->getOwner();
+    unsigned operandNumber = consumerOpOperand->getOperandNumber();
+    unsigned resultNumber =
+        cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    // Use candidateSliceOp->getParentOp() because the slice op is nested as:
+    // scf.forall.in_parallel {
+    //   tensor.parallel_insert_slice ...
+    // }
+    rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+    Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+    auto forallOp = cast<scf::ForallOp>(containingOp);
+
+    // 2. Check consumer is not using scf.forall's output as init.
+    auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+    SmallVector<Value> dpsInits =
+        llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+    if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+      return rewriter.notifyMatchFailure(consumerOp,
+                                         "consumer op taking the result of "
+                                         "scf.forall as init is not supported");
+    }
 
-  // 7. Extract offset/sizes/strides required to create the
-  // tensor.parallel_insert_slice for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
-  // 8. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
-  // 9. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
+    Location loc = forallOp.getLoc();
+    // 3. Create new scf.forall op.
+    SmallVector<Value> newOuts(forallOp.getOutputs());
+    newOuts.append(dpsInits);
+    rewriter.setInsertionPoint(consumerOp);
+    auto newforallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+    // 4. Move the loop body to the new op.
+    rewriter.eraseOp(newforallOp.getTerminator());
+    Block *loopBody = forallOp.getBody();
+    Block *newLoopBody = newforallOp.getBody();
+    rewriter.mergeBlocks(
+        loopBody, newLoopBody,
+        newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+    // 5.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
+    rewriter.setInsertionPointAfter(candidateSliceOp);
+    auto newForOpBlockArgsForConsumerDest =
+        newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
+    auto clonedConsumerOp =
+        cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+            rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+    // 5.b. Replace the use of the scf.forall's result in the cloned consumer
+    // with the source of the cloned tensor.parallel_insert_slice.
+    OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+    rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+      operandToReplace.set(candidateSliceOp.getSource());
+    });
+
+    // 6. Perform tiling of the cloned consumer.
+    rewriter.setInsertionPoint(newforallOp.getTerminator());
+    FailureOr<TilingResult> tileAndFuseResult =
+        tensor::replaceInsertSliceWithTiledConsumer(
+            rewriter,
+            cast<OffsetSizeAndStrideOpInterface>(
+                candidateSliceOp.getOperation()),
+            clonedConsumerOp->getOpOperand(operandNumber));
+    if (failed(tileAndFuseResult)) {
+      return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                         "failed to tile consumer op: ");
+    }
 
-  // 10. Try to fetch the offset and size for all results of the cloned
-  // consumer. This would then be used to form the corresponding
-  // tensor.parallel_insert_slice later.
-  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-    if (failed(clonedConsumerOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
+    // 7. Extract offset/sizes/strides required to create the
+    // tensor.parallel_insert_slice for each result of the consumer.
+    SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+    // 8. Check that all insert strides are 1.
+    if (llvm::any_of(strides, [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
+      return rewriter.notifyMatchFailure(
+          candidateSliceOp, "containingOp's result yield with stride");
+    }
+    // 9. Try to get iter domain position from input position.
+    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+    rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
+    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+            iterDomainSizes))) {
       return rewriter.notifyMatchFailure(
           clonedConsumerOp,
-          "can't get result domain position from iter domain position");
+          "can't get iter domain position from input position");
     }
-  }
 
-  // 11. Fix terminator.
-  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
-  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
-  Location firstYieldOpLoc =
-      (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
-  MutableArrayRef<BlockArgument> bbArgs = newforallOp.getBody()->getArguments();
-  for (auto [idx, v] :
-       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
-    SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
-                                      rewriter.getIndexAttr(1));
-    rewriter.create<tensor::ParallelInsertSliceOp>(
-        firstYieldOpLoc, v,
-        bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
-        resultOffsets[idx], resultSizes[idx], strides);
-  }
-
-  // 12. Replace the result of scf.forall and consumer op.
-  for (auto &&[oldResult, newResult] :
-       llvm::zip_first(forallOp.getResults(), newforallOp.getResults())) {
-    rewriter.replaceAllUsesWith(oldResult, newResult);
-  }
+    // 10. Try to fetch the offset and size for all results of the cloned
+    // consumer. This would then be used to form the corresponding
+    // tensor.parallel_insert_slice later.
+    unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+        totalNumResultsOfConsumer);
+    SmallVector<SmallVector<OpFoldResult>> resultSizes(
+        totalNumResultsOfConsumer);
+    for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+      if (failed(clonedConsumerOp.getResultTilePosition(
+              rewriter, idx, iterDomainOffsets, iterDomainSizes,
+              resultOffsets[idx], resultSizes[idx]))) {
+        return rewriter.notifyMatchFailure(
+            clonedConsumerOp,
+            "can't get result domain position from iter domain position");
+      }
+    }
 
-  for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
-    rewriter.replaceAllUsesWith(
-        oldValue, newforallOp->getResult(forallOp.getOutputs().size() + index));
-  }
+    // 11. Fix terminator.
+    scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+    rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+    Location firstYieldOpLoc =
+        (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+    MutableArrayRef<BlockArgument> bbArgs =
+        newforallOp.getBody()->getArguments();
+    for (auto [idx, v] :
+         llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+      SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+                                        rewriter.getIndexAttr(1));
+      rewriter.create<tensor::ParallelInsertSliceOp>(
+          firstYieldOpLoc, v,
+          bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+          resultOffsets[idx], resultSizes[idx], strides);
+    }
 
-  // 13. Need to erase the old scf.forall and cloned consumer.
-  rewriter.eraseOp(forallOp);
-  rewriter.eraseOp(clonedConsumerOp);
+    // 12. Replace the result of scf.forall and consumer op.
+    for (auto &&[oldResult, newResult] :
+         llvm::zip_first(forallOp.getResults(), newforallOp.getResults())) {
+      rewriter.replaceAllUsesWith(oldResult, newResult);
+    }
 
-  return scf::SCFFuseConsumerOfSliceResult{
-      consumerOp, tileAndFuseResult->tiledOps[0], {}};
-}
+    for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+      rewriter.replaceAllUsesWith(
+          oldValue,
+          newforallOp->getResult(forallOp.getOutputs().size() + index));
+    }
 
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
-                                      Operation *candidateSliceOp) {
-  if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
-    return tileAndFuseConsumerOfSliceSCFFor(rewriter, sliceOp);
-  } else if (auto sliceOp =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    return tileAndFuseConsumerOfSliceSCFForall(rewriter, sliceOp);
-  } else {
-    return failure();
+    // 13. Need to erase the old scf.forall and cloned consumer.
+    rewriter.eraseOp(forallOp);
+    rewriter.eraseOp(clonedConsumerOp);
+
+    return scf::SCFFuseConsumerOfSliceResult{
+        consumerOp, tileAndFuseResult->tiledOps[0], {}};
+  }
+
+  /// Implementation of fusing consumer of a single slice by computing the
+  /// slice of the consumer in-place.
+  FailureOr<scf::SCFFuseConsumerOfSliceResult>
+  mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase & rewriter,
+                                        Operation * candidateSliceOp) {
+    if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
+      return tileAndFuseConsumerOfSliceSCFFor(rewriter, sliceOp);
+    } else if (auto sliceOp =
+                   dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+      return tileAndFuseConsumerOfSliceSCFForall(rewriter, sliceOp);
+    } else {
+      return failure();
+    }
   }
-}
 
-//===----------------------------------------------------------------------===//
-// lowerToLoopsUsingSCFForOp implementation.
-//===----------------------------------------------------------------------===//
+  //===----------------------------------------------------------------------===//
+  // lowerToLoopsUsingSCFForOp implementation.
+  //===----------------------------------------------------------------------===//
 
-FailureOr<SmallVector<scf::ForOp>>
-mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
-                                     TilingInterface op) {
-  // TODO: Handle cases where the op has results if needed.
-  if (op->getNumResults() > 0) {
-    return rewriter.notifyMatchFailure(
-        op, "unable to lower to loops operations with return values");
-  }
+  FailureOr<SmallVector<scf::ForOp>> mlir::scf::lowerToLoopsUsingSCFForOp(
+      RewriterBase & rewriter, TilingInterface op) {
+    // TODO: Handle cases where the op has results if needed.
+    if (op->getNumResults() > 0) {
+      return rewriter.notifyMatchFailure(
+          op, "unable to lower to loops operations with return values");
+    }
 
-  SmallVector<Range> domain = op.getIterationDomain(rewriter);
-  SmallVector<Value> ivs;
-  SmallVector<scf::ForOp> loops;
-  Location loc = op.getLoc();
-  for (auto loopRange : domain) {
-    Value offsetVal =
-        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
-    Value sizeVal =
-        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
-    Value strideVal =
-        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
-    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
-                                            strideVal, ValueRange{});
-    loops.push_back(loop);
-    ivs.push_back(loop.getInductionVar());
-    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
-  }
-  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
-    return failure();
+    SmallVector<Range> domain = op.getIterationDomain(rewriter);
+    SmallVector<Value> ivs;
+    SmallVector<scf::ForOp> loops;
+    Location loc = op.getLoc();
+    for (auto loopRange : domain) {
+      Value offsetVal =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+      Value sizeVal =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+      Value strideVal =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
+      auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
+                                              strideVal, ValueRange{});
+      loops.push_back(loop);
+      ivs.push_back(loop.getInductionVar());
+      rewriter.setInsertionPoint(loop.getBody()->getTerminator());
+    }
+    if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
+      return failure();
+    }
+    return loops;
   }
-  return loops;
-}

>From e656591f3e96bb84a88c8e5fef11abbc2a9e9585 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 7 May 2024 10:19:00 +0000
Subject: [PATCH 4/5] Unify the code for scf.for/forall fuse consumer + fix
 crash

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 703 +++++++-----------
 1 file changed, 285 insertions(+), 418 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index eac23bfec4fa4..2743c88a0c5b7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1173,14 +1173,113 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
   return &operand;
 }
 
-typedef struct helper1struct {
-  OpOperand *consumerOpOperand;
-};
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = cast<BlockArgument>(sliceDest);
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    return nullptr;
+  }
+  Value resultingValue =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  // Step 3. Check resulting value of scf.forall has exactly one use.
+  Value::use_range uses = resultingValue.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    return nullptr;
+  }
+
+  // Step 4. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: The consumer's results must be initialized before the scf.forall;
+  //       for now, use DestinationStyleOpInterface to get the result shape
+  //       from the init. Support other ops, e.g. InferTypeOpInterface ops.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return nullptr;
+  }
+  if (containingOp->getBlock() != consumerOp->getBlock()) {
+    return nullptr;
+  }
+  return &operand;
+}
+
+static OpOperand *getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
+    return getUntiledConsumerFromSlice(insertSlice);
+  } else if (auto parallelInsertSlice =
+                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+  } else {
+    return nullptr;
+  }
+}
+
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      TilingResult tilingResult,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+                      SmallVector<OpFoldResult> &strides, unsigned initSize) {
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  MutableArrayRef<BlockArgument> bbArgs = newForOp.getBody()->getArguments();
+  Location loc = newForOp.getLoc();
+  for (auto [idx, v] :
+       llvm::enumerate(tilingResult.tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, v, bbArgs[1 + initSize + idx], resultOffsets[idx],
+        resultSizes[idx], strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+}
+
+static void fixTerminatorSCFInParallel(
+    RewriterBase &rewriter, scf::ForallOp newForallOp,
+    TilingResult tilingResult,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+    SmallVector<OpFoldResult> &strides, unsigned initSize, unsigned rank) {
+  scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
+  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+  Location firstYieldOpLoc =
+      (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+  MutableArrayRef<BlockArgument> bbArgs = newForallOp.getBody()->getArguments();
+  for (auto [idx, v] :
+       llvm::enumerate(tilingResult.tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOpLoc, v, bbArgs[rank + initSize + idx], resultOffsets[idx],
+        resultSizes[idx], strides);
+  }
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+          candidateSliceOp))
+    return failure();
+
+  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
 
-template <typename T>
-static helper1struct helper1(T candidateSliceOp) {
   // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice.
+  // tensor.insert_slice/parallel_insert_slice.
   OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
   if (!consumerOpOperand) {
     return rewriter.notifyMatchFailure(candidateSliceOp,
@@ -1191,443 +1290,211 @@ static helper1struct helper1(T candidateSliceOp) {
   unsigned resultNumber =
       cast<OpResult>(consumerOpOperand->get()).getResultNumber();
 
-  if (std::is_same<T, tensor::InsertSliceOp>::value) {
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+  Operation *oldLoopOp = nullptr;
+  SmallVector<Value> newOuts;
+  Block *oldLoopBody = nullptr;
+  unsigned initSize = 0;
+  unsigned rank = 1;
+  if (isInsertSliceOp) {
+    auto forOp = candidateSliceOp->template getParentOfType<scf::ForOp>();
+    SmallVector<Value> forOpOuts(forOp.getInits());
+    oldLoopOp = forOp;
+    newOuts = forOpOuts;
+    oldLoopBody = forOp.getBody();
+    initSize = forOp.getInits().size();
+  } else {
+    auto forallOp = candidateSliceOp->template getParentOfType<scf::ForallOp>();
+    SmallVector<Value> forallOpOuts(forallOp.getOutputs());
+    oldLoopOp = forallOp;
+    newOuts = forallOpOuts;
+    oldLoopBody = forallOp.getBody();
+    initSize = forallOp.getOutputs().size();
+    rank = forallOp.getRank();
+  }
 
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(candidateSliceOp);
+  OpBuilder::InsertionGuard g(rewriter);
 
-    // 2. Check consumer is not using scf.for's output as init.
-    auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
-    SmallVector<Value> dpsInits =
-        llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-    if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
-      return rewriter.notifyMatchFailure(
-          consumerOp,
-          "consumer op taking the result of scf.for as init is not supported");
-    }
+  // 2. Check consumer is not using scf loop's output as init.
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+  newOuts.append(dpsInits);
 
-    Location loc = forOp.getLoc();
-    SmallVector<Value> newOuts(forOp.getInits());
-    newOuts.append(dpsInits);
+  Location loc = oldLoopOp->getLoc();
 
-    // 3. Create new scf.for op.
-    rewriter.setInsertionPoint(consumerOp);
-    auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+  // 3. Create new scf loop op.
+  rewriter.setInsertionPoint(consumerOp);
+  Operation *newLoopOp = nullptr;
+  Block *newLoopBody = nullptr;
+  if (isInsertSliceOp) {
+    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                 forOp.getUpperBound(),
                                                 forOp.getStep(), newOuts);
-    // 4. Move the loop body to the new op.
-    Block *loopBody = forOp.getBody();
-    Block *newLoopBody = newforOp.getBody();
-    rewriter.mergeBlocks(
-        loopBody, newLoopBody,
-        newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
-
-    helper1struct ret;
-    ret.consumerOpOperand = consumerOpOperand;
-    ret.newLoop = newforOp;
-    return ret;
-  }
-
-  /// Implementation of fusing consumer of a single slice by computing the
-  /// slice of the consumer in-place for scf.for.
-  static FailureOr<scf::SCFFuseConsumerOfSliceResult>
-  tileAndFuseConsumerOfSliceSCFFor(RewriterBase & rewriter,
-                                   tensor::InsertSliceOp candidateSliceOp) {
-    // 1. Get the consumer of scf.for for the result yielded by
-    // tensor.insert_slice.
-    OpOperand *consumerOpOperand =
-        getUntiledConsumerFromSlice(candidateSliceOp);
-    if (!consumerOpOperand) {
-      return rewriter.notifyMatchFailure(candidateSliceOp,
-                                         "could not fetch consumer to fuse");
-    }
-    Operation *consumerOp = consumerOpOperand->getOwner();
-    unsigned operandNumber = consumerOpOperand->getOperandNumber();
-    unsigned resultNumber =
-        cast<OpResult>(consumerOpOperand->get()).getResultNumber();
-
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(candidateSliceOp);
-
-    // 2. Check consumer is not using scf.for's output as init.
-    auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
-    SmallVector<Value> dpsInits =
-        llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-    if (llvm::is_contained(dpsInits, forOp.getResult(resultNumber))) {
-      return rewriter.notifyMatchFailure(
-          consumerOp,
-          "consumer op taking the result of scf.for as init is not supported");
-    }
-
-    Location loc = forOp.getLoc();
-    SmallVector<Value> newOuts(forOp.getInits());
-    newOuts.append(dpsInits);
+    newLoopOp = newForOp;
+    newLoopBody = newForOp.getBody();
+  } else {
+    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto newForallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+    newLoopOp = newForallOp;
+    rewriter.eraseOp(newForallOp.getTerminator());
+    newLoopBody = newForallOp.getBody();
+  }
 
-    // 3. Create new scf.for op.
-    rewriter.setInsertionPoint(consumerOp);
-    auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
-                                                forOp.getUpperBound(),
-                                                forOp.getStep(), newOuts);
-    // 4. Move the loop body to the new op.
-    Block *loopBody = forOp.getBody();
-    Block *newLoopBody = newforOp.getBody();
-    rewriter.mergeBlocks(
-        loopBody, newLoopBody,
-        newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
-
-    // 5.a. Clone consumer after the cloned tensor.insert_slice op.
-    rewriter.setInsertionPointAfter(candidateSliceOp);
-    auto newForOpBlockArgsForConsumerDest =
-        newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
-    auto clonedConsumerOp =
-        cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-            rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
-
-    // 5.b. Replace all uses of the loop result with the result of the cloned
-    // tensor.insert_slice.
-    OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
-    rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-      operandToReplace.set(candidateSliceOp.getResult());
-    });
-
-    // 6 - Perform tiling of the cloned consumer.
-    rewriter.setInsertionPointAfter(clonedConsumerOp);
-    FailureOr<TilingResult> tileAndFuseResult =
-        tensor::replaceInsertSliceWithTiledConsumer(
-            rewriter,
-            cast<OffsetSizeAndStrideOpInterface>(
-                candidateSliceOp.getOperation()),
-            clonedConsumerOp->getOpOperand(operandNumber));
-    if (failed(tileAndFuseResult)) {
-      return rewriter.notifyMatchFailure(clonedConsumerOp,
-                                         "failed to tile consumer op: ");
+  // 4. Move the loop body to the new op.
+  unsigned oldNumArguments = oldLoopBody->getNumArguments();
+  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(oldNumArguments));
+
+  // 5.a. Clone consumer after the cloned
+  // tensor.insert_slice/parallel_insert_slice op.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(oldNumArguments);
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 5.b. In the cloned consumer, replace the loop result with the result of
+  // tensor.insert_slice or the source of tensor.parallel_insert_slice.
+  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+    if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
+      operandToReplace.set(sliceOp.getResult());
+    } else if (auto sliceOp =
+                   dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+      operandToReplace.set(sliceOp.getSource());
     }
+  });
 
-    // 7 - Extract offset/sizes/strides required to create the
-    // tensor.insert_slice for each result of the consumer.
-    SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
-    // 8. Check all insert stride is 1.
-    if (llvm::any_of(strides, [](OpFoldResult stride) {
-          return !isConstantIntValue(stride, 1);
-        })) {
-      return rewriter.notifyMatchFailure(
-          candidateSliceOp, "containingOp's result yield with stride");
-    }
-    // 9. Try to get iter domain position from input position.
-    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  // 6 - Perform tiling of the cloned consumer.
+  if (isInsertSliceOp) {
     rewriter.setInsertionPointAfter(clonedConsumerOp);
-    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-            iterDomainSizes))) {
-      return rewriter.notifyMatchFailure(
-          clonedConsumerOp,
-          "can't get iter domain position from input position");
-    }
-
-    // 10. Try to fetch the offset and size for all results of the cloned
-    // consumer. This would then be used to form the corresponding
-    // tensor.insert_slice later.
-    unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-        totalNumResultsOfConsumer);
-    SmallVector<SmallVector<OpFoldResult>> resultSizes(
-        totalNumResultsOfConsumer);
-    for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-      if (failed(clonedConsumerOp.getResultTilePosition(
-              rewriter, idx, iterDomainOffsets, iterDomainSizes,
-              resultOffsets[idx], resultSizes[idx]))) {
-        return rewriter.notifyMatchFailure(
-            clonedConsumerOp,
-            "can't get result domain position from iter domain position");
-      }
-    }
-
-    // 11. Fix terminator.
-    scf::YieldOp oldTerminatorOp =
-        cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
-    SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
-    rewriter.setInsertionPointAfter(oldTerminatorOp);
-    MutableArrayRef<BlockArgument> bbArgs = newforOp.getBody()->getArguments();
-    for (auto [idx, v] :
-         llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
-      SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
-                                        rewriter.getIndexAttr(1));
-      Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-          candidateSliceOp->getLoc(), v,
-          bbArgs[1 + forOp.getInits().size() + idx], resultOffsets[idx],
-          resultSizes[idx], strides);
-      newYieldOperands.push_back(newInsertSliceOp);
-    }
-    rewriter.create<scf::YieldOp>(loc, newYieldOperands);
-    rewriter.eraseOp(oldTerminatorOp);
-
-    // 12. Replace the result of scf.for and consumer op.
-    for (auto &&[oldResult, newResult] :
-         llvm::zip_first(forOp.getResults(), newforOp.getResults())) {
-      rewriter.replaceAllUsesWith(oldResult, newResult);
-    }
-
-    for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
-      rewriter.replaceAllUsesWith(
-          oldValue, newforOp->getResult(forOp.getInits().size() + index));
-    }
-
-    // 13. Need to erase the old scf.for and the cloned consumer op.
-    rewriter.eraseOp(forOp);
-    rewriter.eraseOp(clonedConsumerOp);
-
-    return scf::SCFFuseConsumerOfSliceResult{
-        consumerOp, tileAndFuseResult->tiledOps[0], {}};
-  }
-
-  /// Fetch the first untiled consumer of a scf.forall's result which is yielded
-  /// by a tensor.parallel_insert_slice.
-  static OpOperand *getUntiledConsumerFromSlice(
-      tensor::ParallelInsertSliceOp candidateSliceOp) {
-    // Step 1. Fetch the corresponding output
-    Value sliceDest = candidateSliceOp.getDest();
-    auto iterArg = cast<BlockArgument>(sliceDest);
-    Operation *containingOp = iterArg.getOwner()->getParentOp();
-    // Step 2. Check that the containing op is scf.forall.
-    auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-    if (!forallOp) {
-      return nullptr;
-    }
-    Value resultingValue =
-        forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-    // Step 3. Check resulting value of scf.forall has exactly one use.
-    Value::use_range uses = resultingValue.getUses();
-    if (!llvm::hasSingleElement(uses)) {
-      return nullptr;
-    }
+  } else {
+    rewriter.setInsertionPoint(cast<scf::ForallOp>(newLoopOp).getTerminator());
+  }
+  auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                       "failed to tile consumer op: ");
+  }
 
-    // Step 4. Get uses.
-    OpOperand &operand = (*resultingValue.getUses().begin());
-    Operation *consumerOp = operand.getOwner();
-    // TODO: We have to init result of consumer before scf.forall, use
-    //       DestinationStyleOpInterface to get result shape from init for now.
-    //       Add support for other op such as op has InferTypeOpInterface.
-    if (!isa<TilingInterface>(consumerOp) ||
-        !isa<DestinationStyleOpInterface>(consumerOp)) {
-      return nullptr;
-    }
-    if (containingOp->getBlock() != consumerOp->getBlock()) {
-      return nullptr;
-    }
-    return &operand;
-  }
-
-  /// Implementation of fusing consumer of a single slice by computing the
-  /// slice of the consumer in-place for scf.forall.
-  static FailureOr<scf::SCFFuseConsumerOfSliceResult>
-  tileAndFuseConsumerOfSliceSCFForall(
-      RewriterBase & rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
-    // 1. Get the consumer of the dest.
-    OpOperand *consumerOpOperand =
-        getUntiledConsumerFromSlice(candidateSliceOp);
-    if (!consumerOpOperand) {
-      return rewriter.notifyMatchFailure(candidateSliceOp,
-                                         "could not fetch consumer to fuse");
-    }
-    Operation *consumerOp = consumerOpOperand->getOwner();
-    unsigned operandNumber = consumerOpOperand->getOperandNumber();
-    unsigned resultNumber =
-        cast<OpResult>(consumerOpOperand->get()).getResultNumber();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    // Using candidateSliceOp->getParentOp() because we have the following case
-    // :- scf.forall.in_parallel {
-    //   tensor.parallel_insert_slice ...
-    // }
-    rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
-
-    Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
-    auto forallOp = cast<scf::ForallOp>(containingOp);
-
-    // 2. Check consumer is not using scf.forall's output as init.
-    auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
-    SmallVector<Value> dpsInits =
-        llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-    if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
-      return rewriter.notifyMatchFailure(consumerOp,
-                                         "consumer op taking the result of "
-                                         "scf.forall as init is not supported");
-    }
+  // 7 - Extract offset/sizes/strides required to create the
+  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
 
-    Location loc = forallOp.getLoc();
-    // 3. Create new scf.forall op.
-    SmallVector<Value> newOuts(forallOp.getOutputs());
-    newOuts.append(dpsInits);
-    rewriter.setInsertionPoint(consumerOp);
-    auto newforallOp = rewriter.create<scf::ForallOp>(
-        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
-        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  // 8. Check that all insert strides are 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
 
-    // 4. Move the loop body to the new op.
-    rewriter.eraseOp(newforallOp.getTerminator());
-    Block *loopBody = forallOp.getBody();
-    Block *newLoopBody = newforallOp.getBody();
-    rewriter.mergeBlocks(
-        loopBody, newLoopBody,
-        newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
-
-    // 5.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
-    rewriter.setInsertionPointAfter(candidateSliceOp);
-    auto newForOpBlockArgsForConsumerDest =
-        newLoopBody->getArguments().drop_front(loopBody->getNumArguments());
-    auto clonedConsumerOp =
-        cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-            rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
-
-    // 5.b. Replace all uses of the scf.forall's result use in the consumer with
-    // the source of the cloned tensor.parallel_insert_slice.
-    OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
-    rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-      operandToReplace.set(candidateSliceOp.getSource());
-    });
-
-    // 6. Perform tiling of the cloned consumer.
-    rewriter.setInsertionPoint(newforallOp.getTerminator());
-    FailureOr<TilingResult> tileAndFuseResult =
-        tensor::replaceInsertSliceWithTiledConsumer(
-            rewriter,
-            cast<OffsetSizeAndStrideOpInterface>(
-                candidateSliceOp.getOperation()),
-            clonedConsumerOp->getOpOperand(operandNumber));
-    if (failed(tileAndFuseResult)) {
-      return rewriter.notifyMatchFailure(clonedConsumerOp,
-                                         "failed to tile consumer op: ");
-    }
+  // 9. Try to get iter domain position from input position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
 
-    // 7. Extract offset/sizes/strides required to create the
-    // tensor.parallel_insert_slice for each result of the consumer.
-    SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
-    // 8. Check all insert stride is 1.
-    if (llvm::any_of(strides, [](OpFoldResult stride) {
-          return !isConstantIntValue(stride, 1);
-        })) {
-      return rewriter.notifyMatchFailure(
-          candidateSliceOp, "containingOp's result yield with stride");
-    }
-    // 9. Try to get iter domain position from input position.
-    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (isInsertSliceOp) {
+    rewriter.setInsertionPointAfter(clonedConsumerOp);
+  } else {
     rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
-    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-            iterDomainSizes))) {
+  }
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 10. Try to fetch the offset and size for all results of the cloned
+  // consumer. This would then be used to form the corresponding
+  // tensor.insert_slice/parallel_insert_slice later.
+  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
       return rewriter.notifyMatchFailure(
           clonedConsumerOp,
-          "can't get iter domain position from input position");
+          "can't get result domain position from iter domain position");
     }
+  }
 
-    // 10. Try to fetch the offset and size for all results of the cloned
-    // consumer. This would then be used to form the corresponding
-    // tensor.parallel_insert_slice later.
-    unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-        totalNumResultsOfConsumer);
-    SmallVector<SmallVector<OpFoldResult>> resultSizes(
-        totalNumResultsOfConsumer);
-    for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-      if (failed(clonedConsumerOp.getResultTilePosition(
-              rewriter, idx, iterDomainOffsets, iterDomainSizes,
-              resultOffsets[idx], resultSizes[idx]))) {
-        return rewriter.notifyMatchFailure(
-            clonedConsumerOp,
-            "can't get result domain position from iter domain position");
-      }
-    }
+  if (isInsertSliceOp) {
+    fixTerminatorSCFYield(rewriter, cast<scf::ForOp>(newLoopOp),
+                          *tileAndFuseResult, resultOffsets, resultSizes,
+                          strides, initSize);
+  } else {
+    fixTerminatorSCFInParallel(rewriter, cast<scf::ForallOp>(newLoopOp),
+                               *tileAndFuseResult, resultOffsets, resultSizes,
+                               strides, initSize, rank);
+  }
 
-    // 11. Fix terminator.
-    scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
-    rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
-    Location firstYieldOpLoc =
-        (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
-    MutableArrayRef<BlockArgument> bbArgs =
-        newforallOp.getBody()->getArguments();
-    for (auto [idx, v] :
-         llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
-      SmallVector<OpFoldResult> strides(resultOffsets[idx].size(),
-                                        rewriter.getIndexAttr(1));
-      rewriter.create<tensor::ParallelInsertSliceOp>(
-          firstYieldOpLoc, v,
-          bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
-          resultOffsets[idx], resultSizes[idx], strides);
-    }
+  // 12. Replace the result of scf loop and consumer op with new loop's results.
+  for (auto &&[oldResult, newResult] :
+       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
+  }
 
-    // 12. Replace the result of scf.forall and consumer op.
-    for (auto &&[oldResult, newResult] :
-         llvm::zip_first(forallOp.getResults(), newforallOp.getResults())) {
-      rewriter.replaceAllUsesWith(oldResult, newResult);
-    }
+  for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(oldValue,
+                                newLoopOp->getResult(initSize + index));
+  }
 
-    for (auto &&[index, oldValue] : llvm::enumerate(consumerOp->getResults())) {
-      rewriter.replaceAllUsesWith(
-          oldValue,
-          newforallOp->getResult(forallOp.getOutputs().size() + index));
-    }
+  // 13. Need to erase the old scf loop and the cloned consumer op.
+  rewriter.eraseOp(oldLoopOp);
+  rewriter.eraseOp(clonedConsumerOp);
 
-    // 13. Need to erase the old scf.forall and cloned consumer.
-    rewriter.eraseOp(forallOp);
-    rewriter.eraseOp(clonedConsumerOp);
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
 
-    return scf::SCFFuseConsumerOfSliceResult{
-        consumerOp, tileAndFuseResult->tiledOps[0], {}};
-  }
+//===----------------------------------------------------------------------===//
+// lowerToLoopsUsingSCFForOp implementation.
+//===----------------------------------------------------------------------===//
 
-  /// Implementation of fusing consumer of a single slice by computing the
-  /// slice of the consumer in-place.
-  FailureOr<scf::SCFFuseConsumerOfSliceResult>
-  mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase & rewriter,
-                                        Operation * candidateSliceOp) {
-    if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
-      return tileAndFuseConsumerOfSliceSCFFor(rewriter, sliceOp);
-    } else if (auto sliceOp =
-                   dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-      return tileAndFuseConsumerOfSliceSCFForall(rewriter, sliceOp);
-    } else {
-      return failure();
-    }
+FailureOr<SmallVector<scf::ForOp>>
+mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
+                                     TilingInterface op) {
+  // TODO: Handle cases where the op has results if needed.
+  if (op->getNumResults() > 0) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to lower to loops operations with return values");
   }
 
-  //===----------------------------------------------------------------------===//
-  // lowerToLoopsUsingSCFForOp implementation.
-  //===----------------------------------------------------------------------===//
-
-  FailureOr<SmallVector<scf::ForOp>> mlir::scf::lowerToLoopsUsingSCFForOp(
-      RewriterBase & rewriter, TilingInterface op) {
-    // TODO: Handle cases where the op has results if needed.
-    if (op->getNumResults() > 0) {
-      return rewriter.notifyMatchFailure(
-          op, "unable to lower to loops operations with return values");
-    }
-
-    SmallVector<Range> domain = op.getIterationDomain(rewriter);
-    SmallVector<Value> ivs;
-    SmallVector<scf::ForOp> loops;
-    Location loc = op.getLoc();
-    for (auto loopRange : domain) {
-      Value offsetVal =
-          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
-      Value sizeVal =
-          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
-      Value strideVal =
-          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
-      auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
-                                              strideVal, ValueRange{});
-      loops.push_back(loop);
-      ivs.push_back(loop.getInductionVar());
-      rewriter.setInsertionPoint(loop.getBody()->getTerminator());
-    }
-    if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
-      return failure();
-    }
-    return loops;
+  SmallVector<Range> domain = op.getIterationDomain(rewriter);
+  SmallVector<Value> ivs;
+  SmallVector<scf::ForOp> loops;
+  Location loc = op.getLoc();
+  for (auto loopRange : domain) {
+    Value offsetVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+    Value sizeVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+    Value strideVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
+    auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
+                                            strideVal, ValueRange{});
+    loops.push_back(loop);
+    ivs.push_back(loop.getInductionVar());
+    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
   }
+  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
+    return failure();
+  }
+  return loops;
+}

>From fc7b121375231b008f90441e2b53699ea50bed66 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 10 May 2024 07:15:32 +0000
Subject: [PATCH 5/5] [MLIR][Tensor] Add consumer tiling implementation for
 tensor.unpack

-- This commit adds a tiling implementation for tensor.unpack as a consumer.

Signed-off-by: Abhishek Varma <avarma094 at gmail.com>
---
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   | 100 ++++++++++++++++++
 .../tile-and-fuse-consumer.mlir               |  60 +++++++++++
 2 files changed, 160 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 296c5fc7a5c2b..a2fc539e6de8f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -469,6 +469,106 @@ struct UnPackOpTiling
       return failure();
     return tilingResult.value();
   }
+
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &resultOffsets,
+      SmallVectorImpl<OpFoldResult> &resultSizes) const {
+    auto unPackOp = cast<UnPackOp>(op);
+    Location loc = unPackOp.getLoc();
+
+    int64_t numTiles = unPackOp.getInnerDimsPos().size();
+    auto destOffsets = offsets.drop_back(numTiles);
+    auto destSizes = sizes.drop_back(numTiles);
+    // The tiling is applied on interchanged dimensions. We have to undo the
+    // interchange to map sizes and offsets to the original input.
+    int64_t outputRank = unPackOp.getDestRank();
+    SmallVector<OpFoldResult> origOffsets(destOffsets.begin(),
+                                          destOffsets.end());
+    SmallVector<OpFoldResult> origSizes(destSizes.begin(), destSizes.end());
+    applyPermToRange(origOffsets, origSizes,
+                     invertPermutationVector(unPackOp.getOuterDimsPerm()));
+
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        unPackOp.getDimAndTileMapping();
+
+    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
+      using AV = affine::AffineValueExpr;
+      affine::AffineBuilder ab(b, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(b.getContext(), dim0, dim1);
+      bindSymbols(b.getContext(), sym);
+      if (dimAndTileMapping.count(dim)) {
+        // If the data dimension is tiled, the i-th index is the product of
+        // offset_i and tile_i, and the i-th size is the product of sizes_i and
+        // tile_i.
+        auto avOffset = AV(dim0).bind(origOffsets[dim]);
+        auto avSize = AV(dim0).bind(origSizes[dim]);
+        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
+        resultSizes.push_back(ab.mul(avSize, avTileSize));
+      } else {
+        resultOffsets.push_back(origOffsets[dim]);
+        resultSizes.push_back(origSizes[dim]);
+      }
+    }
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  getTiledImplementationAsConsumer(Operation *op, OpBuilder &b,
+                                   ArrayRef<OpFoldResult> offsets,
+                                   ArrayRef<OpFoldResult> sizes) const {
+    auto unPackOp = cast<UnPackOp>(op);
+    Location loc = unPackOp.getLoc();
+
+    // Fetch offset/size for creating the slice of the dest operand of
+    // unpack op.
+    SmallVector<OpFoldResult> outputOffsets, outputSizes;
+    if (failed(getIterationDomainTileFromOperandTile(
+            op, b, 0, offsets, sizes, outputOffsets, outputSizes)))
+      return failure();
+
+    auto oneAttr = b.getI64IntegerAttr(1);
+    int64_t outputRank = unPackOp.getDestRank();
+    SmallVector<OpFoldResult> strides(outputRank, oneAttr);
+
+    SmallVector<Value> tiledOperands;
+    // Create slice of the dest operand.
+    auto extractDestSlice = b.create<ExtractSliceOp>(
+        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
+    tiledOperands.push_back(extractDestSlice);
+
+    SmallVector<OpFoldResult> inputOffsets, inputSizes;
+    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
+    // Create slice of the source operand.
+    auto extractSourceSlice = b.create<ExtractSliceOp>(
+        loc, unPackOp.getSource(), offsets, sizes, strides);
+    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
+    for (auto tile : unPackOp.getInnerTiles())
+      tiledOperands.push_back(tile);
+
+    // Create tiled unpack op.
+    Operation *tiledUnPackOp =
+        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
+                           tiledOperands, op->getAttrs());
+
+    return TilingResult{{tiledUnPackOp},
+                        SmallVector<Value>(tiledUnPackOp->getResults())};
+  }
+
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    auto unpackOp = cast<UnPackOp>(op);
+    FailureOr<TilingResult> tilingResult =
+        getTiledImplementationAsConsumer(unpackOp, b, offsets, sizes);
+
+    if (failed(tilingResult))
+      return failure();
+    return tilingResult.value();
+  }
 };
 
 } // namespace
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 3d60e32bfa0cc..380f4cfb71036 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -256,3 +256,63 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#3 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+    func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
+        %c4 = arith.constant 4 : index
+        %c64 = arith.constant 64 : index
+        %c0 = arith.constant 0 : index
+        %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+            %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+            %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+                ^bb0(%in: f32, %in_16: f32, %out: f32):
+                %13 = arith.mulf %in, %in_16 : f32
+                %14 = arith.addf %out, %13 : f32
+                linalg.yield %14 : f32
+            } -> tensor<32x32xf32>
+            scf.forall.in_parallel {
+                tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+            }
+        }
+        %output = tensor.empty() : tensor<2048xf32>
+        %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
+        return %unpack : tensor<2048xf32>
+    }
+}
+  
+module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+        %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+        %a, %b = transform.test.fuse_consumer %slice_op
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+        transform.yield
+    }
+}
+//      CHECK: #[[UNPACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+//      CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
+//      CHECK:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_MAP]](%[[IV1]])
+//      CHECK:      %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
+//      CHECK:      %[[TILED_UNPACK_SRC:.*]] = tensor.extract_slice %[[GENERIC_OUT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[TILED_UNPACK_SRC]]
+// CHECK-SAME:                              outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:                              into %[[TILED_UNPACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :



More information about the Mlir-commits mailing list