[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 15 20:59:23 PDT 2024


https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/94190

>From 80a980334ac6e1d46ff918dd3f07832486ebc7a6 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Sun, 2 Jun 2024 23:32:22 -0700
Subject: [PATCH 1/3] extend consumer fuse to nested scf loop

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 731 ++++++++++++------
 .../tile-and-fuse-consumer.mlir               |  96 +++
 2 files changed, 606 insertions(+), 221 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a1392813d6de3..8a7eac1d883f9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1197,98 +1198,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 // tileAndFuseConsumerUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
-  Value result = candidateSliceOp.getResult();
-  Value::use_range uses = result.getUses();
-  if (!llvm::hasSingleElement(uses)) {
-    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
-    return failure();
-  }
-  OpOperand &operandUse = (*uses.begin());
-  Operation *userOp = operandUse.getOwner();
-  if (!isa<scf::YieldOp>(userOp)) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Expected scf.yield to be the only user, but got -> "
-               << (*userOp));
-    return failure();
-  }
-  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
-    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
-                               "be in the same block\n");
-    return failure();
-  }
-  return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
-                                                  Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
-    return failure();
-  return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1.  tensor.insert_slice has scf.yield as its only user.
-/// 2.  scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
-  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
-    return failure();
-  Value sliceResult = candidateSliceOp.getResult();
-  // Step 1. Fetch the corresponding output.
-  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
-  unsigned resultNumber = yieldOpOperand.getOperandNumber();
-  // Step 2. Check containing op is scf.for.
-  Operation *containingOp = candidateSliceOp->getParentOp();
-  auto forOp = dyn_cast<scf::ForOp>(containingOp);
-  if (!forOp)
-    return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
-  Value sliceDest = candidateSliceOp.getDest();
-  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
-  if (!iterArg)
-    return failure();
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
-    return failure();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp)
-    return failure();
-  Value resultingValue =
-      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
 /// This utility currently checks whether the loop either :-
 /// 1. Yields exactly one result.
 /// 2. Has consumer op as its first user and other users to be in the same
@@ -1314,31 +1223,117 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
   return success();
 }
 
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(insertSlice);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(parallelInsertSlice);
-  } else {
+/// Traverse and collect all outer loops of the given sliceOp, ordered from
+/// outermost to innermost. If `untilLoop` is found, stop the walk early.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+    Operation *sliceOp,
+    std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+  assert(isa<OffsetSizeAndStrideOpInterface>(sliceOp));
+  SmallVector<LoopLikeOpInterface> outerLoops;
+  auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+  while (forOp) {
+    outerLoops.push_back(forOp);
+    if (untilLoop.has_value() && *untilLoop == forOp)
+      break;
+    forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+  }
+  return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+/// Get the result of top-level loop which yields the target InsertSliceOp. E.g
+/// ```
+/// %1 = scf.for
+///  %2 = scf.for
+///   %3 = scf.for
+///      ...
+///      %4 = insert
+///      yield %4
+///   %5 = insert %3
+///   yield %5
+///  yield %2
+/// ```
+/// @param targetSliceOp: %4 = insert
+/// @return first: resultValue: %1
+///         second: Collected insertSliceOp List during walk including
+///                   targetSliceOp: %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(Operation *targetSliceOp,
+                                          int curDepth = 0, int maxDepth = 5) {
+  assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+  // Bound the recursion depth to avoid stack overflow.
+  if (curDepth > maxDepth)
+    return failure();
+
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+  candidateSliceOpList.push_back(
+      cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+  Value resultOfLoop;
+  if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
+    Value destValue = sliceOp.getDest();
+    auto iterArg = cast<BlockArgument>(destValue);
+    auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+    if (!forallOp)
+      return failure();
+    resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
+    Value resultValue = sliceOp.getResult();
+    for (auto &useOperand : resultValue.getUses()) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        if (llvm::detail::isPresent(resultOfLoop))
+          return failure();
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+      }
+    }
+  }
+
+  if (!llvm::detail::isPresent(resultOfLoop))
     return failure();
+
+  while (true) {
+    bool walkThroughOuterLoop = false;
+    for (OpOperand &useOperand : resultOfLoop.getUses()) {
+      if (auto sliceOp =
+              dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+        FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+            resultAndSliceOpsPair = getResultOfTopLevelLoopYieldInsertSliceOp(
+                sliceOp, curDepth + 1);
+        if (failed(resultAndSliceOpsPair))
+          return failure();
+        candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+                                    (*resultAndSliceOpsPair).second.end());
+        return std::make_pair((*resultAndSliceOpsPair).first,
+                              candidateSliceOpList);
+      } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        // walk through outer loop
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+        walkThroughOuterLoop = true;
+        break;
+      }
+    }
+    if (!walkThroughOuterLoop)
+      break;
   }
+  return std::make_pair(resultOfLoop, candidateSliceOpList);
 }
 
 /// After fusing consumer into scf.for we want to modify the scf.yield operation
 /// to reflect the same by returning the values yielded by the tiled consumer.
 static void
 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      TilingResult &tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ResultRange tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
                       ArrayRef<BlockArgument> bbArgs) {
   scf::YieldOp oldTerminatorOp =
       cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
   unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  unsigned totalTiledResults = tilingResult.size();
   SmallVector<Value> newYieldOperands;
   newYieldOperands.reserve(totalOldResults + totalTiledResults);
   for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1347,8 +1342,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
   rewriter.setInsertionPointAfter(oldTerminatorOp);
   Location loc = newForOp.getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
-                       resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1363,16 +1357,16 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
 /// values by the tiled consumer within scf.forall.in_parallel region.
 static void
 fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                           ResultRange tilingResult,
+                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> resultSizes,
                            ArrayRef<BlockArgument> bbArgs) {
   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
   rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
   Location firstYieldOpLoc =
       (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1380,6 +1374,176 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
   }
 }
 
+/// If the top-level loop of a nested loop structure is scf.forall, additional
+/// tensor.extract_slice ops must be created for its newly appended
+/// `shared_outs` so that the inner loops receive the correct local slices. E.g.
+///
+/// scf.forall shared_outs(%o1=..., %o2=...) {
+///     %local_o1 = extract_slice %o1
+///     // fix new appended `shared_out` %o2
+///     %local_o2 = extract_slice %o2
+///     scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+///        ...
+///     }
+///     ...
+/// }
+static SmallVector<tensor::ExtractSliceOp> fixLoopInitFromSharedOutSCFForall(
+    RewriterBase &rewriter, Operation *loop, ArrayRef<Value> newSharedOuts,
+    ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+    ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+  rewriter.setInsertionPoint(loop);
+  Location loc = loop->getLoc();
+  // create new ExtractOps for newInits from scf.forall
+  SmallVector<tensor::ExtractSliceOp> newExtractOps;
+  newExtractOps.reserve(resultOffsets.size());
+  for (auto [bbArg, offset, sizes] :
+       llvm::zip_equal(newSharedOuts, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, bbArg, offset, sizes, strides);
+    newExtractOps.push_back(newExtractOp);
+  }
+  return newExtractOps;
+}
+
+/// If the outermost loop of the nested loop structure is `scf.forall`, replace
+/// each DPS init of the tiled consumer with an extract_slice of its bbArg.
+static void
+fixDpsInitsOfTiledConsumer(RewriterBase &rewriter, Operation *tiledConsumer,
+                           ArrayRef<BlockArgument> bbArgs,
+                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+  rewriter.setInsertionPoint(tiledConsumer);
+  Location loc = tiledConsumer->getLoc();
+  for (auto &&[bbArg, offset, sizes, dpsInit] :
+       llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+                       cast<DestinationStyleOpInterface>(tiledConsumer)
+                           .getDpsInitsMutable())) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, bbArg, offset, sizes, strides);
+    dpsInit.set(newExtractOp.getResult());
+  }
+}
+
+/// Compute the tile of every result from the given SliceOp on one operand.
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+    RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+    OffsetSizeAndStrideOpInterface ossSliceOp,
+    SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+  // 1. Check that all strides are 1
+  if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+  }
+  // 2. Compute iteration domain tile from the input position
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+          ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        tilableOp, "can't get iter domain position from input position");
+  }
+  unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  // 3. Compute result Tile by resultNumber
+  for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
+    if (failed(tilableOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          tilableOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+  allResultOffsets = resultOffsets;
+  allResultSizes = resultSizes;
+  return success();
+}
+
+/// Since multi-level tensor.*SliceOps may be expressed in different
+/// coordinate systems, this utility computes the real OFFSET relative to the
+/// ROOT SliceOp. E.g.
+///             %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+///         %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+///
+/// where the coordinates can be illustrated as follows:
+///
+///  %3 ----------------------------------
+///  |         |         |
+///  | OFFSET2 | OFFSET1 |
+///  | ------ %0         |
+///  |                   |
+///  |                   |
+///  |------------------ %1 ------ |
+///  |                   |  SIZE1  |
+///  |                   |         |
+///  |                   |         |
+///  |                   | ------- |
+///  |
+///
+/// The real OFFSET of %1 relative to %3 is actually `OFFSET1` + `OFFSET2`
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+    RewriterBase &rewriter, Location loc,
+    OffsetSizeAndStrideOpInterface candidateSliceOp,
+    MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "candidateSliceOp has stride");
+  }
+  SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+  // Real offsets equal the accumulated offsets of the outer candidates
+  for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
+       iter++) {
+    // fail if any outer candidate slice has a non-unit stride
+    if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
+      return failure();
+    }
+    for (auto &&[ofr1, ofr2] :
+         llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
+      using AVE = affine::AffineValueExpr;
+      affine::AffineBuilder ab(rewriter, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(rewriter.getContext(), dim0, dim1);
+      bindSymbols(rewriter.getContext(), sym);
+      auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+      ofr1 = ab.add(aveOffset1, aveOffset2);
+    }
+  }
+  return realOffsets;
+}
+
+/// Get the first tilable user of the given Value and check its dominance at
+/// the same time
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+  for (auto &useOfval : val.getUses()) {
+    Operation *consumerOp = useOfval.getOwner();
+    // 1. Check whether consumerOp is tilable
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp))
+      continue;
+    // 2. Check that it stays in the same block as loopOp
+    if (loopOp->getBlock() != consumerOp->getBlock())
+      continue;
+    // 3. Check no other user before it
+    if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
+      continue;
+    }
+    return &useOfval;
+  }
+  return failure();
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1391,10 +1555,28 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
 
-  // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice/parallel_insert_slice.
+  // 1.a. Get the real consumer of candidate
+  // tensor.insert_slice/parallel_insert_slice by walking through
+  // scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the
+  // way.
+  FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+      resultAndSliceOpsPair =
+          getResultOfTopLevelLoopYieldInsertSliceOp(candidateSliceOp);
+  if (failed(resultAndSliceOpsPair)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+
+  // 1.b. Get all outer loops of candidateSliceOp.
+  SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsOfSliceOp(
+      candidateSliceOp, dyn_cast<OpResult>((*resultAndSliceOpsPair).first)
+                            .getDefiningOp<LoopLikeOpInterface>());
+  LoopLikeOpInterface outerMostLoop = outerLoops.front();
+
+  // 2. Get first tilable consumer op
   FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getUntiledConsumerFromSlice(candidateSliceOp);
+      getTilableConsumerOperandFirstUseVal((*resultAndSliceOpsPair).first,
+                                           outerMostLoop);
   if (failed(maybeConsumerOpOperand)) {
     return rewriter.notifyMatchFailure(candidateSliceOp,
                                        "could not fetch consumer to fuse");
@@ -1410,111 +1592,193 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  Operation *oldLoopOp = nullptr;
-  SmallVector<Value> newOuts;
-  Block *oldLoopBody = nullptr;
-  unsigned initSize = 0;
-  unsigned rank = 1;
-  if (isInsertSliceOp) {
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-    oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
-    initSize = forOp.getInits().size();
-  } else {
-    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
-    oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
-    initSize = forallOp.getOutputs().size();
-    rank = forallOp.getRank();
-  }
-
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
-    return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
-  }
-
-  OpBuilder::InsertionGuard g(rewriter);
-
-  // 2. Check consumer is not using scf loop's output as init.
+  // 3. Check consumer is not using outerMostLoop's output as init.
   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
 
-  Location loc = oldLoopOp->getLoc();
+  // 4.a. Reconstruct nested loop from outer to inner.
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
+      (*resultAndSliceOpsPair).second;
+  SmallVector<LoopLikeOpInterface> newOuterLoops;
+  SmallVector<SmallVector<OpFoldResult>> allResultOffsets, allResultSizes;
+  SmallVector<tensor::ExtractSliceOp> newExtractOps;
 
-  // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
-  Operation *newLoopOp = nullptr;
+  Block *oldLoopBody = nullptr;
   Block *newLoopBody = nullptr;
+  SmallVector<Value> newInitAppend = dpsInits, newOuts;
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // 4.b. Set insertPoint right before consumerOp
+  rewriter.setInsertionPoint(consumerOp);
+
+  for (auto [index, loop] :
+       llvm::enumerate(MutableArrayRef(outerLoops).drop_back())) {
+    if (index > 0) {
+      rewriter.setInsertionPoint(loop);
+      // 4.c. Create `extractSliceOp` for newInits if they come from the
+      // sharedOut of the previous `scf.forall` loop.
+      if (auto prevOuterLoop =
+              dyn_cast<scf::ForallOp>(newOuterLoops.back().getOperation())) {
+        if (index != 1) {
+          return rewriter.notifyMatchFailure(
+              prevOuterLoop, "Currently only outerMostLoop assumed forallOp");
+        }
+        OffsetSizeAndStrideOpInterface outerMostCandidate =
+            candidateSliceOpList.back();
+        if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+                rewriter, cast<TilingInterface>(consumerOp), operandNumber,
+                outerMostCandidate, allResultOffsets, allResultSizes))) {
+          return failure();
+        }
+        newExtractOps = fixLoopInitFromSharedOutSCFForall(
+            rewriter, loop, newInitAppend, allResultOffsets, allResultSizes);
+        newInitAppend = llvm::map_to_vector(
+            newExtractOps,
+            [](tensor::ExtractSliceOp op) -> Value { return op.getResult(); });
+      }
+    }
+    LoopLikeOpInterface newLoopOp;
+    // 4.d. Create a new loop with the new init values for this loop.
+    if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+      newOuts = llvm::to_vector(forOp.getInits());
+      newOuts.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForOp>(
+          forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+          forOp.getStep(), newOuts);
+      newLoopOp = newLoop;
+      oldLoopBody = forOp.getBody();
+      newLoopBody = newLoop.getBody();
+    } else if (auto forallOp = dyn_cast<scf::ForallOp>(loop.getOperation())) {
+      newOuts = llvm::to_vector(forallOp.getOutputs());
+      newOuts.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForallOp>(
+          forallOp.getLoc(), forallOp.getMixedLowerBound(),
+          forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+          forallOp.getMapping());
+      rewriter.eraseOp(newLoop.getTerminator());
+      newLoopOp = newLoop;
+      oldLoopBody = forallOp.getBody();
+      newLoopBody = newLoop.getBody();
+    }
+    newInitAppend = llvm::map_to_vector(
+        newLoopBody->getArguments().take_back(newInitAppend.size()),
+        [](BlockArgument bArg) -> Value { return bArg; });
+    rewriter.mergeBlocks(
+        oldLoopBody, newLoopBody,
+        newLoopBody->getArguments().take_front(oldLoopBody->getNumArguments()));
+    rewriter.replaceOp(
+        loop, newLoopOp->getResults().take_front(loop->getNumResults()));
+    newOuterLoops.push_back(newLoopOp);
+  }
+
+  // 5.a. Reconstruct inner-most loop.
+  LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
+  Location loc = oldInnerMostLoop->getLoc();
+  if (outerLoops.size() > 1)
+    rewriter.setInsertionPoint(oldInnerMostLoop);
+
   if (isInsertSliceOp) {
-    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto forOp = cast<scf::ForOp>(oldInnerMostLoop.getOperation());
+    newOuts = llvm::to_vector(forOp.getInits());
+    newOuts.append(newInitAppend.begin(), newInitAppend.end());
+    oldLoopBody = forOp.getBody();
     auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                 forOp.getUpperBound(),
                                                 forOp.getStep(), newOuts);
-    newLoopOp = newForOp;
+    newInnerMostLoop = newForOp;
     newLoopBody = newForOp.getBody();
   } else {
-    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto forallOp = cast<scf::ForallOp>(oldInnerMostLoop.getOperation());
+    newOuts = llvm::to_vector(forallOp.getOutputs());
+    newOuts.append(newInitAppend.begin(), newInitAppend.end());
+    oldLoopBody = forallOp.getBody();
     auto newForallOp = rewriter.create<scf::ForallOp>(
         loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
         forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-    newLoopOp = newForallOp;
+    newInnerMostLoop = newForallOp;
     rewriter.eraseOp(newForallOp.getTerminator());
     newLoopBody = newForallOp.getBody();
   }
 
-  // 4. Move the loop body to the new op.
+  // 5.b. Move the loop body to the new op.
   unsigned oldNumArguments = oldLoopBody->getNumArguments();
   rewriter.mergeBlocks(oldLoopBody, newLoopBody,
                        newLoopBody->getArguments().take_front(oldNumArguments));
+  // 5.c. Replace the results of oldInnerMostLoop with newInnerMostLoop's
+  // results.
+  rewriter.replaceOp(oldInnerMostLoop,
+                     newInnerMostLoop->getResults().take_front(
+                         oldInnerMostLoop->getNumResults()));
 
-  // 5. Set insertion point before terminator op of the loop and create a new
+  // 6. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
   tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
     rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
   }
 
-  // 6.a. Clone consumer op.
-  auto newForOpBlockArgsForConsumerDest =
-      newLoopBody->getArguments().drop_front(oldNumArguments);
+  // 7.a. `getTiledImplementation` currently assumes that all operands are
+  // untiled and have the original tensor size, so create a dummy
+  // `insertSliceOp` to satisfy that requirement.
+  auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
+  FailureOr<SmallVector<OpFoldResult>> realOffsets =
+      computeRealOffsetsCoordinatedRootSliceOp(rewriter, loc, ossSliceOp,
+                                               candidateSliceOpList);
+  if (failed(realOffsets))
+    return failure();
+  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+      loc, candidateSliceOp->getOperand(0),
+      candidateSliceOpList.back()->getOperand(1), *realOffsets,
+      ossSliceOp.getMixedSizes(), ossSliceOp.getMixedStrides());
+
+  SmallVector<Value> newDpsInitsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(oldNumArguments),
+      [](BlockArgument bArg) -> Value { return bArg; });
+
+  // 7.b. If the outerMostLoop is scf.forall, then `newExtractOps` were
+  // additionally created at `step 4.c` for `dpsInits`. As a counterpart, an
+  // `insertSliceOp` is also needed, for the same purpose as in `step 7.a`.
+  if (!newExtractOps.empty()) {
+    for (auto &&[extractOp, newDpsInit] :
+         llvm::zip_equal(newExtractOps, newDpsInitsForConsumerDest)) {
+      auto alignDpsInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+          loc, newDpsInit, extractOp.getSource(), extractOp.getMixedOffsets(),
+          extractOp.getMixedSizes(), extractOp.getMixedStrides());
+      newDpsInit = alignDpsInsertSliceOp.getResult();
+    }
+  }
+
+  // 8.a. Clone consumer op.
   auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+      rewriter, consumerOp, newDpsInitsForConsumerDest));
 
-  // 6.b. Replace all uses of the loop result with the result of the cloned
+  // 8.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
     operandToReplace.set(clonedInsertSliceOp.getResult());
   });
 
-  // 7 - Perform tiling of the cloned consumer and replace the operand at
+  // 9. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
-  auto ossSliceOp =
-      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
       tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedInsertSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
   if (failed(tileAndFuseResult)) {
     return failure();
   }
@@ -1522,75 +1786,100 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
       tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
       clonedInsertSliceOp.getSource());
 
-  // 8 - Extract offset/sizes/strides required to create the
-  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
-  // 9. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
-
-  // 10. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
-
-  // 11. Try to fetch the offset and size for all results of the cloned
+  // 10. Try to fetch the offset and size for all results of the cloned
   // consumer. This would then be used to form the corresponding
   // tensor.insert_slice/parallel_insert_slice later.
-  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-    if (failed(clonedConsumerOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
-      return rewriter.notifyMatchFailure(
-          clonedConsumerOp,
-          "can't get result domain position from iter domain position");
-    }
+  if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+          rewriter, clonedConsumerOp, operandNumber, ossSliceOp,
+          allResultOffsets, allResultSizes))) {
+    return failure();
+  }
+
+  if (!newExtractOps.empty()) {
+    fixDpsInitsOfTiledConsumer(
+        rewriter, tileAndFuseResult->tiledOps[0],
+        newLoopBody->getArguments().drop_front(oldNumArguments),
+        allResultOffsets, allResultSizes);
   }
 
-  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
-  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
   if (isInsertSliceOp) {
-    auto newForOp = cast<scf::ForOp>(newLoopOp);
+    auto newForOp = cast<scf::ForOp>(newInnerMostLoop);
     fixTerminatorSCFYield(
-        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
-        newForOp.getBody()->getArguments().drop_front(1 + initSize));
+        rewriter, newForOp, tileAndFuseResult->tiledOps[0]->getResults(),
+        allResultOffsets, allResultSizes,
+        newForOp.getBody()->getArguments().take_back(newInitAppend.size()));
   } else {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
     fixTerminatorSCFInParallel(
         rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
-        arrayRefOffsets, arrayRefSizes,
-        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
+        allResultOffsets, allResultSizes,
+        newForallOp.getBody()->getArguments().take_back(newInitAppend.size()));
   }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
-  for (auto &&[oldResult, newResult] :
-       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
-    rewriter.replaceAllUsesWith(oldResult, newResult);
+  newOuterLoops.push_back(cast<LoopLikeOpInterface>(newInnerMostLoop));
+
+  // 11.a. Reconstruct terminator of outer loop by inner loop.
+  auto outerCandidateIter = candidateSliceOpList.rbegin();
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip_equal(MutableArrayRef(newOuterLoops).drop_back(),
+                       MutableArrayRef(newOuterLoops).drop_front())) {
+    // 11.b. Create insertSliceOp according outer candidateSliceOp
+    if (outerCandidateIter != candidateSliceOpList.rend() &&
+        outerCandidateIter->getOperation()
+                ->getParentOfType<LoopLikeOpInterface>() == outerLoop) {
+      if (auto forallOp = dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
+        rewriter.setInsertionPoint(forallOp.getTerminator());
+      } else {
+        rewriter.setInsertionPointAfter(*outerCandidateIter);
+      }
+
+      if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+              rewriter, clonedConsumerOp, operandNumber, *outerCandidateIter,
+              allResultOffsets, allResultSizes))) {
+        return failure();
+      }
+
+      if (auto forOp = dyn_cast<scf::ForOp>(outerLoop.getOperation())) {
+        fixTerminatorSCFYield(
+            rewriter, forOp,
+            innerLoop->getResults().take_back(newInitAppend.size()),
+            allResultOffsets, allResultSizes,
+            forOp.getBody()->getArguments().take_back(newInitAppend.size()));
+      } else if (auto forallOp =
+                     dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
+        fixTerminatorSCFInParallel(
+            rewriter, forallOp,
+            innerLoop->getResults().take_back(newInitAppend.size()),
+            allResultOffsets, allResultSizes,
+            forallOp.getBody()->getArguments().take_back(newInitAppend.size()));
+      }
+      outerCandidateIter++;
+    } else {
+      // 11.c. Yield additional new appended results of innerLoop
+      assert(isa<scf::ForOp>(outerLoop));
+      auto forOp = cast<scf::ForOp>(outerLoop);
+      auto outerLoopYield =
+          cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+      SmallVector<Value> newYields =
+          llvm::to_vector(outerLoopYield.getOperands());
+      ValueRange additionalYields =
+          innerLoop->getResults().take_back(newInitAppend.size());
+      newYields.append(additionalYields.begin(), additionalYields.end());
+      rewriter.setInsertionPoint(outerLoopYield);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+    }
   }
 
+  // 12. Replace the result of consumer op with new outerMost loop's
+  // results.
   for (auto &&[oldResult, newResult] :
        llvm::zip(consumerOp->getResults(),
-                 newLoopOp->getResults().drop_front(initSize))) {
+                 newOuterLoops.front()->getResults().take_back(
+                     newInitAppend.size()))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the old scf loop and the cloned consumer op.
-  rewriter.eraseOp(oldLoopOp);
+  // 13. Need to erase the cloned consumer op.
   rewriter.eraseOp(clonedConsumerOp);
 
   return scf::SCFFuseConsumerOfSliceResult{
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcd..d60d5f4fd7b3c 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,3 +315,99 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg3)
+      %iv1 = affine.apply #map(%arg4)
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        scf.yield %3 : tensor<128x128xf32>
+      }
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
+    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %5 : tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//      CHECK: #[[MAP0:.*]] =  affine_map<(d0) -> (d0 * 128)>
+//      CHECK: #[[MAP1:.*]] =  affine_map<(d0, d1) -> (d0 + d1 * 128)>
+//      CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[dest1:.*]] = linalg.fill
+// CHECK-SAME:          outs(%[[dest0]] :
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME:   {
+//      CHECK:      %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+//      CHECK:      %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+//      CHECK:      %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+//      CHECK:      %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+//      CHECK:      %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME:          iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
+// CHECK-SAME:      {
+//      CHECK:          %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME:            iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
+// CHECK-SAME:          {
+//      CHECK:            %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+//      CHECK:            %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE1]] :
+//      CHECK:            %[[REAL_SECOND_IV1:.*]] = affine.apply #[[MAP1]](%[[IV3]], %[[IV1]])
+//      CHECK:            %[[REAL_SECOND_IV2:.*]] = affine.apply #[[MAP1]](%[[IV4]], %[[IV2]])
+//      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[REAL_SECOND_IV1]], %[[REAL_SECOND_IV2]]] [64, 64] [1, 1]
+//      CHECK:            %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
+// CHECK-SAME:              outs(%[[ADD_OUT_SLICE1]] :
+//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
+//      CHECK:          }
+//      CHECK:          scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
+//      CHECK:      }
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :

>From be282cd1e917c426509cff631ba31960313621ad Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Mon, 24 Jun 2024 18:22:21 -0700
Subject: [PATCH 2/3] Revert "extend consumer fuse to nested scf loop (v1)"

This reverts commit 5ba402530cd9909ae76136d8d253e4897987dfb7.
---
 .../SCF/Transforms/TileUsingInterface.cpp     | 731 ++++++------------
 .../tile-and-fuse-consumer.mlir               |  96 ---
 2 files changed, 221 insertions(+), 606 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 8a7eac1d883f9..a1392813d6de3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1198,6 +1197,98 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 // tileAndFuseConsumerUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
+/// A utility function that checks whether the only use of the result of a
+/// tensor.insert_slice op is in a scf.yield op.
+static LogicalResult
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+  Value result = candidateSliceOp.getResult();
+  Value::use_range uses = result.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = (*uses.begin());
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp));
+    return failure();
+  }
+  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Fetches the OpOperand of the only user (and use) of the value `val` which
+/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
+/// failure otherwise.
+static FailureOr<OpOperand *> getConsumerFromUses(Value val,
+                                                  Block *containingOpBlock) {
+  // Step 1. Check that the value has exactly one use.
+  if (!llvm::hasSingleElement(val.getUses()))
+    return failure();
+  // Step 2. Get uses.
+  OpOperand &operand = (*val.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp))
+    return failure();
+  if (containingOpBlock != consumerOp->getBlock())
+    return failure();
+  return &operand;
+}
+
+/// Fetch the untiled consumer of a scf.for's result which is yielded by a
+/// tensor.insert_slice. This function makes the following assumptions :
+/// 1.  tensor.insert_slice has scf.yield as its only user.
+/// 2.  scf.for's corresponding result has only one use.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
+    return failure();
+  Value sliceResult = candidateSliceOp.getResult();
+  // Step 1. Fetch the corresponding output.
+  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  // Step 2. Check containing op is scf.for.
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp)
+    return failure();
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  return getConsumerFromUses(resultingValue, containingOp->getBlock());
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  if (!iterArg)
+    return failure();
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
+    return failure();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp)
+    return failure();
+  Value resultingValue =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+
+  return getConsumerFromUses(resultingValue, containingOp->getBlock());
+}
+
 /// This utility currently checks whether the loop either :-
 /// 1. Yields exactly one result.
 /// 2. Has consumer op as its first user and other users to be in the same
@@ -1223,117 +1314,31 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
   return success();
 }
 
-/// Traverse and collect all outer loops of given sliceOp, sorted by
-/// outer-to-inner. If `untilLoop` found, stop walk through in advance.
-static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
-    Operation *sliceOp,
-    std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
-  assert(isa<OffsetSizeAndStrideOpInterface>(sliceOp));
-  SmallVector<LoopLikeOpInterface> outerLoops;
-  auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
-  while (forOp) {
-    outerLoops.push_back(forOp);
-    if (untilLoop.has_value() && *untilLoop == forOp)
-      break;
-    forOp = forOp->getParentOfType<LoopLikeOpInterface>();
-  }
-  return {outerLoops.rbegin(), outerLoops.rend()};
-}
-
-/// Get the result of top-level loop which yields the target InsertSliceOp. E.g
-/// ```
-/// %1 = scf.for
-///  %2 = scf.for
-///   %3 = scf.for
-///      ...
-///      %4 = insert
-///      yield %4
-///   %5 = insert %3
-///   yield %5
-///  yield %2
-/// ```
-/// @param targetSliceOp: %4 = insert
-/// @return first: resultValue: %1
-///         second: Collected insertSliceOp List during walk including
-///                   targetSliceOp: %4 = insert and %5 = insert %3
-static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
-getResultOfTopLevelLoopYieldInsertSliceOp(Operation *targetSliceOp,
-                                          int curDepth = 0, int maxDepth = 5) {
-  assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
-  // Control recursive time in avoid of stack overflow
-  if (curDepth > maxDepth)
-    return failure();
-
-  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
-  candidateSliceOpList.push_back(
-      cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
-  Value resultOfLoop;
-  if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
-    Value destValue = sliceOp.getDest();
-    auto iterArg = cast<BlockArgument>(destValue);
-    auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
-    if (!forallOp)
-      return failure();
-    resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-  } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
-    Value resultValue = sliceOp.getResult();
-    for (auto &useOperand : resultValue.getUses()) {
-      if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
-        if (llvm::detail::isPresent(resultOfLoop))
-          return failure();
-        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
-        if (!forOp)
-          return failure();
-        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
-      }
-    }
-  }
-
-  if (!llvm::detail::isPresent(resultOfLoop))
+/// A utility to fetch an untiled consumer of
+/// tensor.insert_slice/tensor.parallel_insert_slice.
+static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
+    return getUntiledConsumerFromSlice(insertSlice);
+  } else if (auto parallelInsertSlice =
+                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+  } else {
     return failure();
-
-  while (true) {
-    bool walkThroughOuterLoop = false;
-    for (OpOperand &useOperand : resultOfLoop.getUses()) {
-      if (auto sliceOp =
-              dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
-        FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
-            resultAndSliceOpsPair = getResultOfTopLevelLoopYieldInsertSliceOp(
-                sliceOp, curDepth + 1);
-        if (failed(resultAndSliceOpsPair))
-          return failure();
-        candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
-                                    (*resultAndSliceOpsPair).second.end());
-        return std::make_pair((*resultAndSliceOpsPair).first,
-                              candidateSliceOpList);
-      } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
-        // walk through outer loop
-        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
-        if (!forOp)
-          return failure();
-        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
-        walkThroughOuterLoop = true;
-        break;
-      }
-    }
-    if (!walkThroughOuterLoop)
-      break;
   }
-  return std::make_pair(resultOfLoop, candidateSliceOpList);
 }
 
 /// After fusing consumer into scf.for we want to modify the scf.yield operation
 /// to reflect the same by returning the values yielded by the tiled consumer.
 static void
 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      ResultRange tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+                      TilingResult &tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
                       ArrayRef<BlockArgument> bbArgs) {
   scf::YieldOp oldTerminatorOp =
       cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
   unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.size();
+  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
   SmallVector<Value> newYieldOperands;
   newYieldOperands.reserve(totalOldResults + totalTiledResults);
   for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1342,7 +1347,8 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
   rewriter.setInsertionPointAfter(oldTerminatorOp);
   Location loc = newForOp.getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
+                       resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1357,16 +1363,16 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
 /// values by the tiled consumer within scf.forall.in_parallel region.
 static void
 fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           ResultRange tilingResult,
-                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+                           SmallVector<Value> tiledResults,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
                            ArrayRef<BlockArgument> bbArgs) {
   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
   rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
   Location firstYieldOpLoc =
       (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1374,176 +1380,6 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
   }
 }
 
-/// If the top level loop of nested loop structure is scf.forall, need to create
-/// additional tensor.extract_slice for its new appended `shared_outs` in order
-/// to pass correct local memory for inner loops. E.g.
-///
-/// scf.forall shared_outs(%o1=..., %o2=...) {
-///     %local_o1 = extract_slice %o1
-///     // fix new appended `shared_out` %o2
-///     %local_o2 = extract_slice %o2
-///     scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
-///        ...
-///     }
-///     ...
-/// }
-static SmallVector<tensor::ExtractSliceOp> fixLoopInitFromSharedOutSCFForall(
-    RewriterBase &rewriter, Operation *loop, ArrayRef<Value> newSharedOuts,
-    ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
-    ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
-  rewriter.setInsertionPoint(loop);
-  Location loc = loop->getLoc();
-  // create new ExtractOps for newInits from scf.forall
-  SmallVector<tensor::ExtractSliceOp> newExtractOps;
-  newExtractOps.reserve(resultOffsets.size());
-  for (auto [bbArg, offset, sizes] :
-       llvm::zip_equal(newSharedOuts, resultOffsets, resultSizes)) {
-    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
-    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
-        loc, bbArg, offset, sizes, strides);
-    newExtractOps.push_back(newExtractOp);
-  }
-  return newExtractOps;
-}
-
-/// If outerMost loop of nested loop structure is `scf.forall`, need to deal
-/// with DpsInit of tiled consumer
-static void
-fixDpsInitsOfTiledConsumer(RewriterBase &rewriter, Operation *tiledConsumer,
-                           ArrayRef<BlockArgument> bbArgs,
-                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
-  rewriter.setInsertionPoint(tiledConsumer);
-  Location loc = tiledConsumer->getLoc();
-  for (auto &&[bbArg, offset, sizes, dpsInit] :
-       llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
-                       cast<DestinationStyleOpInterface>(tiledConsumer)
-                           .getDpsInitsMutable())) {
-    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
-    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
-        loc, bbArg, offset, sizes, strides);
-    dpsInit.set(newExtractOp.getResult());
-  }
-}
-
-/// Compute all results tile by given SliceOp along operand
-static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
-    RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
-    OffsetSizeAndStrideOpInterface ossSliceOp,
-    SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
-    SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
-  // 1. Check all stride all 1
-  if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
-  }
-  // 2. Compute iteration domain tile from the input position
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(tilableOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
-          ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        tilableOp, "can't get iter domain position from input position");
-  }
-  unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  // 3. Compute result Tile by resultNumber
-  for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
-    if (failed(tilableOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
-      return rewriter.notifyMatchFailure(
-          tilableOp,
-          "can't get result domain position from iter domain position");
-    }
-  }
-  allResultOffsets = resultOffsets;
-  allResultSizes = resultSizes;
-  return success();
-}
-
-/// Considering multi-level tensor.*SliceOp maybe based on different
-/// coordination, this utility computes the real OFFSET coordinated on ROOT
-/// SliceOp. E.g
-///             %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
-///         %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
-///
-/// where the coordination can be illustrated as follow:
-///
-///  %3 ----------------------------------
-///  |         |         |
-///  | OFFSET2 | OFFSET1 |
-///  | ------ %0         |
-///  |                   |
-///  |                   |
-///  |------------------ %1 ------ |
-///  |                   |  SIZE1  |
-///  |                   |         |
-///  |                   |         |
-///  |                   | ------- |
-///  |
-///
-/// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
-static FailureOr<SmallVector<OpFoldResult>>
-computeRealOffsetsCoordinatedRootSliceOp(
-    RewriterBase &rewriter, Location loc,
-    OffsetSizeAndStrideOpInterface candidateSliceOp,
-    MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
-  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(candidateSliceOp,
-                                       "candidateSliceOp has stride");
-  }
-  SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
-  // Real offsets equals to accumulative offsets of outer candidates
-  for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
-       iter++) {
-    // assert each outer candidate slice has no stride
-    if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
-          return !isConstantIntValue(stride, 1);
-        })) {
-      return failure();
-    }
-    for (auto &&[ofr1, ofr2] :
-         llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
-      using AVE = affine::AffineValueExpr;
-      affine::AffineBuilder ab(rewriter, loc);
-      AffineExpr dim0, dim1, sym;
-      bindDims(rewriter.getContext(), dim0, dim1);
-      bindSymbols(rewriter.getContext(), sym);
-      auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
-      ofr1 = ab.add(aveOffset1, aveOffset2);
-    }
-  }
-  return realOffsets;
-}
-
-/// Get the first tilable user of given Value and check its domination at the
-/// same time
-static FailureOr<OpOperand *>
-getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
-  for (auto &useOfval : val.getUses()) {
-    Operation *consumerOp = useOfval.getOwner();
-    // 1. Check whether consumerOp is tilable
-    if (!isa<TilingInterface>(consumerOp) ||
-        !isa<DestinationStyleOpInterface>(consumerOp))
-      continue;
-    // 2. Check stay in same block with loopOp
-    if (loopOp->getBlock() != consumerOp->getBlock())
-      continue;
-    // 3. Check no other user before it
-    if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
-      continue;
-    }
-    return &useOfval;
-  }
-  return failure();
-}
-
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1555,28 +1391,10 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
 
-  // 1.a. Get the real consumer of candidate
-  // tensor.insert_slice/parallel_insert_slice by walking through
-  // scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the
-  // way.
-  FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
-      resultAndSliceOpsPair =
-          getResultOfTopLevelLoopYieldInsertSliceOp(candidateSliceOp);
-  if (failed(resultAndSliceOpsPair)) {
-    return rewriter.notifyMatchFailure(candidateSliceOp,
-                                       "could not fetch consumer to fuse");
-  }
-
-  // 1.b. Get all outer loops of candidateSliceOp.
-  SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsOfSliceOp(
-      candidateSliceOp, dyn_cast<OpResult>((*resultAndSliceOpsPair).first)
-                            .getDefiningOp<LoopLikeOpInterface>());
-  LoopLikeOpInterface outerMostLoop = outerLoops.front();
-
-  // 2. Get first tilable consumer op
+  // 1. Get the consumer of the loop result that is yielded by the candidate
+  // tensor.insert_slice/parallel_insert_slice.
   FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getTilableConsumerOperandFirstUseVal((*resultAndSliceOpsPair).first,
-                                           outerMostLoop);
+      getUntiledConsumerFromSlice(candidateSliceOp);
   if (failed(maybeConsumerOpOperand)) {
     return rewriter.notifyMatchFailure(candidateSliceOp,
                                        "could not fetch consumer to fuse");
@@ -1592,193 +1410,111 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  // 3. Check consumer is not using outerMostLoop's output as init.
+  Operation *oldLoopOp = nullptr;
+  SmallVector<Value> newOuts;
+  Block *oldLoopBody = nullptr;
+  unsigned initSize = 0;
+  unsigned rank = 1;
+  if (isInsertSliceOp) {
+    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+    oldLoopOp = forOp;
+    llvm::append_range(newOuts, forOp.getInits());
+    oldLoopBody = forOp.getBody();
+    initSize = forOp.getInits().size();
+  } else {
+    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
+    oldLoopOp = forallOp;
+    llvm::append_range(newOuts, forallOp.getOutputs());
+    oldLoopBody = forallOp.getBody();
+    initSize = forallOp.getOutputs().size();
+    rank = forallOp.getRank();
+  }
+
+  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+    return rewriter.notifyMatchFailure(
+        oldLoopOp, "containing loop op should either yield just one value or "
+                   "have the consumer op as its first user");
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 2. Check consumer is not using scf loop's output as init.
   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
+  newOuts.append(dpsInits);
 
-  // 4.a. Reconstruct nested loop from outer to inner.
-  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
-      (*resultAndSliceOpsPair).second;
-  SmallVector<LoopLikeOpInterface> newOuterLoops;
-  SmallVector<SmallVector<OpFoldResult>> allResultOffsets, allResultSizes;
-  SmallVector<tensor::ExtractSliceOp> newExtractOps;
-
-  Block *oldLoopBody = nullptr;
-  Block *newLoopBody = nullptr;
-  SmallVector<Value> newInitAppend = dpsInits, newOuts;
+  Location loc = oldLoopOp->getLoc();
 
-  OpBuilder::InsertionGuard g(rewriter);
-  // 4.b. Set insertPoint right before consumerOp
+  // 3. Create new scf loop op.
   rewriter.setInsertionPoint(consumerOp);
-
-  for (auto [index, loop] :
-       llvm::enumerate(MutableArrayRef(outerLoops).drop_back())) {
-    if (index > 0) {
-      rewriter.setInsertionPoint(loop);
-      // 4.c. Create `extractSliceOp` for newInits if they comes from sharedOut
-      // of previous `scf.forall` loop.
-      if (auto prevOuterLoop =
-              dyn_cast<scf::ForallOp>(newOuterLoops.back().getOperation())) {
-        if (index != 1) {
-          return rewriter.notifyMatchFailure(
-              prevOuterLoop, "Currently only outerMostLoop assumed forallOp");
-        }
-        OffsetSizeAndStrideOpInterface outerMostCandidate =
-            candidateSliceOpList.back();
-        if (failed(computeAllResultTileForOpGivenOperandSliceOp(
-                rewriter, cast<TilingInterface>(consumerOp), operandNumber,
-                outerMostCandidate, allResultOffsets, allResultSizes))) {
-          return failure();
-        }
-        newExtractOps = fixLoopInitFromSharedOutSCFForall(
-            rewriter, loop, newInitAppend, allResultOffsets, allResultSizes);
-        newInitAppend = llvm::map_to_vector(
-            newExtractOps,
-            [](tensor::ExtractSliceOp op) -> Value { return op.getResult(); });
-      }
-    }
-    LoopLikeOpInterface newLoopOp;
-    // 4.d. Create a new loop with the new init values for this loop.
-    if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
-      newOuts = llvm::to_vector(forOp.getInits());
-      newOuts.append(newInitAppend.begin(), newInitAppend.end());
-      auto newLoop = rewriter.create<scf::ForOp>(
-          forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-          forOp.getStep(), newOuts);
-      newLoopOp = newLoop;
-      oldLoopBody = forOp.getBody();
-      newLoopBody = newLoop.getBody();
-    } else if (auto forallOp = dyn_cast<scf::ForallOp>(loop.getOperation())) {
-      newOuts = llvm::to_vector(forallOp.getOutputs());
-      newOuts.append(newInitAppend.begin(), newInitAppend.end());
-      auto newLoop = rewriter.create<scf::ForallOp>(
-          forallOp.getLoc(), forallOp.getMixedLowerBound(),
-          forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
-          forallOp.getMapping());
-      rewriter.eraseOp(newLoop.getTerminator());
-      newLoopOp = newLoop;
-      oldLoopBody = forallOp.getBody();
-      newLoopBody = newLoop.getBody();
-    }
-    newInitAppend = llvm::map_to_vector(
-        newLoopBody->getArguments().take_back(newInitAppend.size()),
-        [](BlockArgument bArg) -> Value { return bArg; });
-    rewriter.mergeBlocks(
-        oldLoopBody, newLoopBody,
-        newLoopBody->getArguments().take_front(oldLoopBody->getNumArguments()));
-    rewriter.replaceOp(
-        loop, newLoopOp->getResults().take_front(loop->getNumResults()));
-    newOuterLoops.push_back(newLoopOp);
-  }
-
-  // 5.a. Reconstruct inner-most loop.
-  LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
-  Location loc = oldInnerMostLoop->getLoc();
-  if (outerLoops.size() > 1)
-    rewriter.setInsertionPoint(oldInnerMostLoop);
-
+  Operation *newLoopOp = nullptr;
+  Block *newLoopBody = nullptr;
   if (isInsertSliceOp) {
-    auto forOp = cast<scf::ForOp>(oldInnerMostLoop.getOperation());
-    newOuts = llvm::to_vector(forOp.getInits());
-    newOuts.append(newInitAppend.begin(), newInitAppend.end());
-    oldLoopBody = forOp.getBody();
+    auto forOp = cast<scf::ForOp>(oldLoopOp);
     auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                 forOp.getUpperBound(),
                                                 forOp.getStep(), newOuts);
-    newInnerMostLoop = newForOp;
+    newLoopOp = newForOp;
     newLoopBody = newForOp.getBody();
   } else {
-    auto forallOp = cast<scf::ForallOp>(oldInnerMostLoop.getOperation());
-    newOuts = llvm::to_vector(forallOp.getOutputs());
-    newOuts.append(newInitAppend.begin(), newInitAppend.end());
-    oldLoopBody = forallOp.getBody();
+    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
     auto newForallOp = rewriter.create<scf::ForallOp>(
         loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
         forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-    newInnerMostLoop = newForallOp;
+    newLoopOp = newForallOp;
     rewriter.eraseOp(newForallOp.getTerminator());
     newLoopBody = newForallOp.getBody();
   }
 
-  // 5.b. Move the loop body to the new op.
+  // 4. Move the loop body to the new op.
   unsigned oldNumArguments = oldLoopBody->getNumArguments();
   rewriter.mergeBlocks(oldLoopBody, newLoopBody,
                        newLoopBody->getArguments().take_front(oldNumArguments));
-  // 5.c. Replace the result of old oldInnerMostLoop with newInnerMostLoop's
-  // results.
-  rewriter.replaceOp(oldInnerMostLoop,
-                     newInnerMostLoop->getResults().take_front(
-                         oldInnerMostLoop->getNumResults()));
 
-  // 6. Set insertion point before terminator op of the loop and create a new
+  // 5. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
   tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
+    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
     rewriter.setInsertionPoint(newForallOp.getTerminator());
+    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
+        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
     rewriter.setInsertionPoint(candidateSliceOp);
+    clonedInsertSliceOp =
+        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
   }
 
-  // 7.a. Due to current assumption of `getTiledImplementation` that all
-  // `Operands` are untiled with original tensor size, create dummy
-  // `insertSliceOp` to align with that requirement.
-  auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
-  FailureOr<SmallVector<OpFoldResult>> realOffsets =
-      computeRealOffsetsCoordinatedRootSliceOp(rewriter, loc, ossSliceOp,
-                                               candidateSliceOpList);
-  if (failed(realOffsets))
-    return failure();
-  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-      loc, candidateSliceOp->getOperand(0),
-      candidateSliceOpList.back()->getOperand(1), *realOffsets,
-      ossSliceOp.getMixedSizes(), ossSliceOp.getMixedStrides());
-
-  SmallVector<Value> newDpsInitsForConsumerDest = llvm::map_to_vector(
-      newLoopBody->getArguments().drop_front(oldNumArguments),
-      [](BlockArgument bArg) -> Value { return bArg; });
-
-  // 7.b. If the outerMostLoop is scf.forall, then the `newExtractOps` has been
-  // additionally created at `step 4.e` for `dpsInits`. As the counterpart, the
-  // `insertSliceOp` is also needed for the same purpose with `step 7.a`.
-  if (!newExtractOps.empty()) {
-    for (auto &&[extractOp, newDpsInit] :
-         llvm::zip_equal(newExtractOps, newDpsInitsForConsumerDest)) {
-      auto alignDpsInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-          loc, newDpsInit, extractOp.getSource(), extractOp.getMixedOffsets(),
-          extractOp.getMixedSizes(), extractOp.getMixedStrides());
-      newDpsInit = alignDpsInsertSliceOp.getResult();
-    }
-  }
-
-  // 8.a. Clone consumer op.
+  // 6.a. Clone consumer op.
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(oldNumArguments);
   auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-      rewriter, consumerOp, newDpsInitsForConsumerDest));
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
 
-  // 8.b. Replace all uses of the loop result with the result of the cloned
+  // 6.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
     operandToReplace.set(clonedInsertSliceOp.getResult());
   });
 
-  // 9. Perform tiling of the cloned consumer and replace the operand at
+  // 7. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
+  auto ossSliceOp =
+      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
       tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter,
-          cast<OffsetSizeAndStrideOpInterface>(
-              clonedInsertSliceOp.getOperation()),
-          clonedConsumerOp->getOpOperand(operandNumber));
+          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
   if (failed(tileAndFuseResult)) {
     return failure();
   }
@@ -1786,100 +1522,75 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
       tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
       clonedInsertSliceOp.getSource());
 
-  // 10. Try to fetch the offset and size for all results of the cloned
-  // consumer. This would then be used to form the corresponding
-  // tensor.insert_slice/parallel_insert_slice later.
-  if (failed(computeAllResultTileForOpGivenOperandSliceOp(
-          rewriter, clonedConsumerOp, operandNumber, ossSliceOp,
-          allResultOffsets, allResultSizes))) {
-    return failure();
+  // 8. Extract offset/sizes/strides required to create the
+  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+  // 9. Check that all insert strides are 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  // 10. Try to get iter domain position from input position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
   }
 
-  if (!newExtractOps.empty()) {
-    fixDpsInitsOfTiledConsumer(
-        rewriter, tileAndFuseResult->tiledOps[0],
-        newLoopBody->getArguments().drop_front(oldNumArguments),
-        allResultOffsets, allResultSizes);
+  // 11. Try to fetch the offset and size for all results of the cloned
+  // consumer. This would then be used to form the corresponding
+  // tensor.insert_slice/parallel_insert_slice later.
+  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
   }
 
+  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
+  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
   if (isInsertSliceOp) {
-    auto newForOp = cast<scf::ForOp>(newInnerMostLoop);
+    auto newForOp = cast<scf::ForOp>(newLoopOp);
     fixTerminatorSCFYield(
-        rewriter, newForOp, tileAndFuseResult->tiledOps[0]->getResults(),
-        allResultOffsets, allResultSizes,
-        newForOp.getBody()->getArguments().take_back(newInitAppend.size()));
+        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
+        newForOp.getBody()->getArguments().drop_front(1 + initSize));
   } else {
-    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
+    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
     fixTerminatorSCFInParallel(
         rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
-        allResultOffsets, allResultSizes,
-        newForallOp.getBody()->getArguments().take_back(newInitAppend.size()));
+        arrayRefOffsets, arrayRefSizes,
+        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
   }
 
-  newOuterLoops.push_back(cast<LoopLikeOpInterface>(newInnerMostLoop));
-
-  // 11.a. Reconstruct terminator of outer loop by inner loop.
-  auto outerCandidateIter = candidateSliceOpList.rbegin();
-  for (auto [outerLoop, innerLoop] :
-       llvm::zip_equal(MutableArrayRef(newOuterLoops).drop_back(),
-                       MutableArrayRef(newOuterLoops).drop_front())) {
-    // 11.b. Create insertSliceOp according outer candidateSliceOp
-    if (outerCandidateIter != candidateSliceOpList.rend() &&
-        outerCandidateIter->getOperation()
-                ->getParentOfType<LoopLikeOpInterface>() == outerLoop) {
-      if (auto forallOp = dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
-        rewriter.setInsertionPoint(forallOp.getTerminator());
-      } else {
-        rewriter.setInsertionPointAfter(*outerCandidateIter);
-      }
-
-      if (failed(computeAllResultTileForOpGivenOperandSliceOp(
-              rewriter, clonedConsumerOp, operandNumber, *outerCandidateIter,
-              allResultOffsets, allResultSizes))) {
-        return failure();
-      }
-
-      if (auto forOp = dyn_cast<scf::ForOp>(outerLoop.getOperation())) {
-        fixTerminatorSCFYield(
-            rewriter, forOp,
-            innerLoop->getResults().take_back(newInitAppend.size()),
-            allResultOffsets, allResultSizes,
-            forOp.getBody()->getArguments().take_back(newInitAppend.size()));
-      } else if (auto forallOp =
-                     dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
-        fixTerminatorSCFInParallel(
-            rewriter, forallOp,
-            innerLoop->getResults().take_back(newInitAppend.size()),
-            allResultOffsets, allResultSizes,
-            forallOp.getBody()->getArguments().take_back(newInitAppend.size()));
-      }
-      outerCandidateIter++;
-    } else {
-      // 11.c. Yield additional new appended results of innerLoop
-      assert(isa<scf::ForOp>(outerLoop));
-      auto forOp = cast<scf::ForOp>(outerLoop);
-      auto outerLoopYield =
-          cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-      SmallVector<Value> newYields =
-          llvm::to_vector(outerLoopYield.getOperands());
-      ValueRange additionalYields =
-          innerLoop->getResults().take_back(newInitAppend.size());
-      newYields.append(additionalYields.begin(), additionalYields.end());
-      rewriter.setInsertionPoint(outerLoopYield);
-      rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
-    }
+  // 12. Replace the result of scf loop and consumer op with new loop's results.
+  for (auto &&[oldResult, newResult] :
+       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 12. Replace the result of consumer op with new outerMost loop's
-  // results.
   for (auto &&[oldResult, newResult] :
        llvm::zip(consumerOp->getResults(),
-                 newOuterLoops.front()->getResults().take_back(
-                     newInitAppend.size()))) {
+                 newLoopOp->getResults().drop_front(initSize))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the cloned consumer op.
+  // 13. Need to erase the old scf loop and the cloned consumer op.
+  rewriter.eraseOp(oldLoopOp);
   rewriter.eraseOp(clonedConsumerOp);
 
   return scf::SCFFuseConsumerOfSliceResult{
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index d60d5f4fd7b3c..400b558e37fcd 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,99 +315,3 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
-
-// -----
-
-#map = affine_map<(d0) -> (d0 * 128)>
-module {
-  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
-    %c0 = arith.constant 0 : index
-    %c64 = arith.constant 64 : index
-    %c128 = arith.constant 128 : index
-    %cst = arith.constant 0.000000e+00 : f32
-    %dest0 = tensor.empty() : tensor<256x256xf32>
-    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
-      %iv0 = affine.apply #map(%arg3)
-      %iv1 = affine.apply #map(%arg4)
-      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
-      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
-      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
-      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
-        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
-          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
-          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
-          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
-          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
-          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
-          scf.yield %insert_slice : tensor<128x128xf32>
-        }
-        scf.yield %3 : tensor<128x128xf32>
-      }
-      scf.forall.in_parallel {
-         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
-      }
-    }
-    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    return %5 : tensor<256x256xf32>
-  }
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-
-//      CHECK: #[[MAP0:.*]] =  affine_map<(d0) -> (d0 * 128)>
-//      CHECK: #[[MAP1:.*]] =  affine_map<(d0, d1) -> (d0 + d1 * 128)>
-//      CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
-// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
-//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
-//      CHECK:   %[[dest1:.*]] = linalg.fill
-// CHECK-SAME:          outs(%[[dest0]] :
-//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
-// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
-// CHECK-SAME:   {
-//      CHECK:      %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
-//      CHECK:      %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
-//      CHECK:      %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
-//      CHECK:      %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
-//      CHECK:      %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
-//      CHECK:      %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
-//      CHECK:      %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
-// CHECK-SAME:          iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
-// CHECK-SAME:      {
-//      CHECK:          %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
-// CHECK-SAME:            iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
-// CHECK-SAME:          {
-//      CHECK:            %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
-//      CHECK:            %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
-//      CHECK:            %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
-//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
-// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE1]] :
-//      CHECK:            %[[REAL_SECOND_IV1:.*]] = affine.apply #[[MAP1]](%[[IV3]], %[[IV1]])
-//      CHECK:            %[[REAL_SECOND_IV2:.*]] = affine.apply #[[MAP1]](%[[IV4]], %[[IV2]])
-//      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[REAL_SECOND_IV1]], %[[REAL_SECOND_IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
-//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
-// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
-// CHECK-SAME:              outs(%[[ADD_OUT_SLICE1]] :
-//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
-//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
-//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
-//      CHECK:          }
-//      CHECK:          scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
-//      CHECK:      }
-//      CHECK:      scf.forall.in_parallel {
-//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
-//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
-//      CHECK:       }
-//      CHECK:   }
-//      CHECK:   return %[[FINAL_RESULT]]#1 :

>From a015bf75325e22ababc627e664365ea60a09fa12 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Tue, 25 Jun 2024 01:58:59 -0700
Subject: [PATCH 3/3] extend consumer fuse to nested scf loop (v2)

---
 .../SCF/Transforms/TileUsingInterface.h       |   1 +
 .../SCF/Transforms/TileUsingInterface.cpp     | 416 +++++++++++++++---
 .../tile-and-fuse-consumer.mlir               |  98 ++++-
 3 files changed, 462 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index d68ca11207376..d6e8fbfab085a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -254,6 +254,7 @@ struct SCFFuseConsumerOfSliceResult {
       *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
   SmallVector<Operation *> tiledOps;
 };
+
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
 tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a1392813d6de3..a5c85c642aca2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1223,26 +1223,86 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
   return success();
 }
 
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
+/// Fetches the FIRST use of `val` whose owner implements BOTH `TilingInterface`
+/// and `DestinationStyleOpInterface`, is located in `containingOpBlock` and
+/// has at least one use itself. Returns failure if there is no such use.
 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
+  OpOperand *operand = nullptr;
+  for (auto &use : val.getUses()) {
+    Operation *user = use.getOwner();
+    // Step 1. Check if the user is tilable. Variadic `isa<A, B>` means "A or
+    // B", so each interface has to be tested on its own.
+    if (!isa<TilingInterface>(user) ||
+        !isa<DestinationStyleOpInterface>(user)) {
+      // TODO: We have to init result of consumer before scf.for, use
+      //       DestinationStyleOpInterface to get result shape from init for
+      //       now. Add support for other op such as op has
+      //       InferTypeOpInterface.
+      continue;
+    }
+    // Step 2. Check if the user stays in the same block.
+    if (containingOpBlock != user->getBlock())
+      continue;
+    // Step 3. Skip users without any use of their own; such a user has
+    // usually been tiled already.
+    if (user->use_empty())
+      continue;
+    operand = &use;
+    break;
+  }
+  if (!operand)
     return failure();
-  return &operand;
+  return operand;
+}
+
+/// Walk up from the given loop and collect the enclosing loops for as long as
+/// the predicate succeeds; the result is sorted from outer to inner.
+///
+/// @param loop: target loop; it is always included in the result, i.e. if no
+///              enclosing loop qualifies, only the loop itself is returned.
+/// @param pred: predicate checked on each enclosing loop; the walk stops at
+///              the first loop for which it fails.
+/// @return the enclosing loops that satisfy `pred`, followed by `loop`.
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given with a predicate that always succeeds, this
+/// function will return the three nested loops: %0, %1 and %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+    LoopLikeOpInterface loop,
+    const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
+  SmallVector<LoopLikeOpInterface> nestLoops = {loop};
+  auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
+  while (outerLoop && succeeded(pred(outerLoop))) {
+    nestLoops.push_back(outerLoop);
+    outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
+  }
+  // sorted from outer to inner
+  return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
+/// Check if the given loop is a ForOp that yields the result of its inner loop.
+static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
+  if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+    for (auto &&[index, op] :
+         llvm::enumerate(forOp.getBody()->getOperations())) {
+      // If the inner loop is the second-to-last op, i.e. right before the
+      // yieldOp of ForOp, the given loop must yield the result of inner loop.
+      if (isa<LoopLikeOpInterface>(op)) {
+        return (index + 2) == forOp.getBody()->getOperations().size()
+                   ? success()
+                   : failure();
+      }
+    }
+  }
+  return failure();
 }
 
 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
@@ -1262,9 +1322,11 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
   auto forOp = dyn_cast<scf::ForOp>(containingOp);
   if (!forOp)
     return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
+  LoopLikeOpInterface topLevelForOp =
+      getOuterNestLoopsWhile(forOp, isForOpYieldResultOfInnerLoop).front();
+  Value resultingValue = topLevelForOp->getResult(resultNumber);
 
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
+  return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
 }
 
 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1289,28 +1351,108 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
   return getConsumerFromUses(resultingValue, containingOp->getBlock());
 }
 
-/// This utility currently checks whether the loop either :-
-/// 1. Yields exactly one result.
-/// 2. Has consumer op as its first user and other users to be in the same
-/// containing block as that of consumer op's. Currently we clone the loop op
-/// right before the consumer op in order to maintain a valid def-use chain.
-/// This utility thus helps ensuring that no invalid IR is formed due to the
-/// same.
+/// This utility currently checks whether the first userOp of loop is NOT before
+/// the last defineOp of consumer. Currently we clone the loop op right before
+/// a certain op in order to maintain a valid def-use chain. This utility thus
+/// helps ensuring that no invalid IR is formed due to the same. E.g.
+///
+/// ```
+/// %0 = scf.for() {
+///
+/// }
+/// ...
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ...
+/// %3 = consumerOp(%2)
+/// ```
+///
+/// If the `firstUserOfLoop` is before `lastDefOfConsumer`, then it would be
+/// invalid to clone the loop op right before the `firstUserOfLoop`:
+///
+/// ```
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ```
+///
+/// To address this issue, this utility checks that no user of
+/// `firstUserOfLoop` is before `lastDefOfConsumer`. If so, it moves
+/// `firstUserOfLoop` after `lastDefOfConsumer`, which makes the IR valid as
+/// follows:
+///
+/// ```
+/// %2 = lastDefOfConsumer
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ```
+///
+/// @param loopOp: loop operation
+/// @param consumerOp: consumer operation
+/// @param insertPointBefore: which operation we clone the loopOp right before
 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
-                                            Operation *consumerOp) {
-  // Check if the loop op yields one result.
-  if (loopOp->getNumResults() == 1)
-    return success();
-  // Check if the consumerOp is the first user of the loopOp and if other users
-  // are in the same containing block as that of consumer op's.
+                                            Operation *consumerOp,
+                                            Operation **insertPointBefore) {
   Block *parentBlock = consumerOp->getBlock();
+  // loopOp and consumerOp should stay in the same block.
+  if (loopOp->getBlock() != parentBlock)
+    return failure();
+
+  Operation *firstUserOfLoop = consumerOp, *lastDefOfConsumer = loopOp;
+  // Find the first user of loopOp
   for (Operation *userOp : loopOp->getUsers()) {
     if (userOp == consumerOp)
       continue;
-    if (parentBlock != userOp->getBlock() ||
-        !consumerOp->isBeforeInBlock(userOp))
+    // `ParallelInsertSlice` located inside `InParallelOp` shares its parent
+    // block with no other type of operation. Thus, just redirect to its
+    // parent `InParallelOp`.
+    if (isa<tensor::ParallelInsertSliceOp>(userOp)) {
+      userOp = userOp->getParentOfType<scf::InParallelOp>();
+    }
+    if (parentBlock != userOp->getBlock()) {
+      return failure();
+    }
+    if (userOp->isBeforeInBlock(firstUserOfLoop)) {
+      firstUserOfLoop = userOp;
+    }
+  }
+  // Find the last define of consumer
+  for (Value operand : consumerOp->getOperands()) {
+    // If the operand is `BlockArgument`, auto skip.
+    if (isa<BlockArgument>(operand))
+      continue;
+    auto defineOp = operand.getDefiningOp();
+    if (defineOp == loopOp)
+      continue;
+    if (!defineOp || parentBlock != defineOp->getBlock()) {
+      return failure();
+    }
+    if (lastDefOfConsumer->isBeforeInBlock(defineOp)) {
+      lastDefOfConsumer = defineOp;
+    }
+  }
+  if (firstUserOfLoop->isBeforeInBlock(lastDefOfConsumer)) {
+    // Try to move if possible
+    if (std::all_of(firstUserOfLoop->getUsers().begin(),
+                    firstUserOfLoop->getUsers().end(),
+                    [&lastDefOfConsumer, &parentBlock](Operation *userOp) {
+                      return userOp->getBlock() == parentBlock &&
+                             lastDefOfConsumer->isBeforeInBlock(userOp);
+                    })) {
+      // Safely moving
+      firstUserOfLoop->moveAfter(lastDefOfConsumer);
+    } else {
       return failure();
+    }
   }
+  // Set InsertPoint
+  *insertPointBefore = firstUserOfLoop;
   return success();
 }
 
@@ -1382,9 +1524,9 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
 
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
-                                      Operation *candidateSliceOp) {
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
+                               Operation *candidateSliceOp) {
   if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
           candidateSliceOp))
     return failure();
@@ -1418,45 +1560,83 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (isInsertSliceOp) {
     auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
     oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
     initSize = forOp.getInits().size();
   } else {
     auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
     oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
     initSize = forallOp.getOutputs().size();
     rank = forallOp.getRank();
   }
 
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+  Operation *oldTopLevelLoop = oldLoopOp;
+  SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
+  if (isInsertSliceOp) {
+    oldNestedForOps =
+        getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
+                               isForOpYieldResultOfInnerLoop);
+    oldTopLevelLoop = oldNestedForOps.front();
+  }
+  // 2.a Check assumption for loop and find suitable insertPoint that loop
+  // structure would be cloned right before.
+  Operation *insertPointBefore = nullptr;
+  if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp,
+                                    &insertPointBefore))) {
     return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
+        oldTopLevelLoop, "containing loop op does not satisfy the assumption "
+                         "and no suitable insertPoint is found");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
 
-  // 2. Check consumer is not using scf loop's output as init.
+  // 2.b Check consumer is not using scf loop's output as init.
   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
+  SmallVector<Value> newInitAppend = dpsInits;
 
   Location loc = oldLoopOp->getLoc();
 
   // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
+  rewriter.setInsertionPoint(insertPointBefore);
+
+  // 3.a Create new outer scf loops if necessary
+  bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
+  if (isNestedForOps) {
+    for (auto &&[index, loopOp] :
+         llvm::enumerate(MutableArrayRef(oldNestedForOps).drop_back())) {
+      auto forOp = cast<scf::ForOp>(loopOp);
+      SmallVector<Value> newInits;
+      newInits = llvm::to_vector(forOp.getInits());
+      newInits.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForOp>(
+          forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+          forOp.getStep(), newInits);
+      newInitAppend = llvm::map_to_vector(
+          newLoop.getRegionIterArgs().take_back(newInitAppend.size()),
+          [](BlockArgument bArg) -> Value { return bArg; });
+      rewriter.mergeBlocks(
+          forOp.getBody(), newLoop.getBody(),
+          newLoop.getBody()->getArguments().take_front(initSize + 1));
+      rewriter.replaceOp(
+          forOp, newLoop->getResults().take_front(forOp->getNumResults()));
+      newNestedForOps.push_back(newLoop);
+      rewriter.setInsertionPointAfter(oldNestedForOps[index + 1]);
+    }
+  }
+
+  // 3.b Create new inner most scf loop
   Operation *newLoopOp = nullptr;
   Block *newLoopBody = nullptr;
   if (isInsertSliceOp) {
     auto forOp = cast<scf::ForOp>(oldLoopOp);
+    llvm::append_range(newOuts, forOp.getInits());
+    newOuts.append(newInitAppend);
+    oldLoopBody = forOp.getBody();
     auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                 forOp.getUpperBound(),
                                                 forOp.getStep(), newOuts);
@@ -1464,6 +1644,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
     newLoopBody = newForOp.getBody();
   } else {
     auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    llvm::append_range(newOuts, forallOp.getOutputs());
+    newOuts.append(newInitAppend);
+    oldLoopBody = forallOp.getBody();
     auto newForallOp = rewriter.create<scf::ForallOp>(
         loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
         forallOp.getMixedStep(), newOuts, forallOp.getMapping());
@@ -1577,28 +1760,159 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         newForallOp.getBody()->getArguments().drop_front(rank + initSize));
   }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
+  // 12.  Restore outer loops from inner to outer
+  if (isNestedForOps) {
+    newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
+    for (auto [outerLoop, innerLoop] :
+         llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
+                         MutableArrayRef(newNestedForOps).drop_front())) {
+      auto forOp = cast<scf::ForOp>(outerLoop);
+      auto outerLoopYield =
+          cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+      SmallVector<Value> newYields =
+          llvm::to_vector(outerLoopYield.getOperands());
+      ValueRange additionalYields =
+          innerLoop->getResults().take_back(newInitAppend.size());
+      newYields.append(additionalYields.begin(), additionalYields.end());
+      rewriter.setInsertionPoint(outerLoopYield);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+    }
+  }
+
+  // 13. Replace the result of scf loop and consumer op with new loop's results.
   for (auto &&[oldResult, newResult] :
        llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
+  Operation *newTopLevelLoop =
+      isNestedForOps ? newNestedForOps.front() : newLoopOp;
   for (auto &&[oldResult, newResult] :
        llvm::zip(consumerOp->getResults(),
-                 newLoopOp->getResults().drop_front(initSize))) {
+                 newTopLevelLoop->getResults().drop_front(initSize))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the old scf loop and the cloned consumer op.
+  // 14. Need to erase the old scf loop and the cloned consumer op.
   rewriter.eraseOp(oldLoopOp);
   rewriter.eraseOp(clonedConsumerOp);
 
+  // 15. Need to erase the cloned insertSliceOp and unused extractSliceOp to
+  // avoid complex dominance analysis.
+  assert(clonedInsertSliceOp->hasOneUse());
+  auto unUsedExtractOp =
+      cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers().begin()));
+  rewriter.eraseOp(unUsedExtractOp);
+  rewriter.eraseOp(clonedInsertSliceOp);
+
   return scf::SCFFuseConsumerOfSliceResult{
       consumerOpOperand,
       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
       tileAndFuseResult->tiledOps};
 }
 
+/// Get the real consumers from candidate InsertSliceOp. E.g.
+///
+/// ```
+/// %1 = scf.for
+///  %2 = scf.for
+///   %3 = scf.for
+///      ...
+///      %4 = insert
+///      yield %4
+///   %5 = insert %3
+///   yield %5
+///  yield %2
+/// %6 = consumerOp ins(%1)
+/// ```
+///
+/// @param candidateSliceOp: %4 = insert
+/// @param forwardSlice: in-out parameter populated by forward insertSliceOps
+/// @param curDepth, maxDepth: current and maximal allowed recursion depth
+/// @return OpOperand consumers: %6 = consumerOp ins(%1)
+static FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp(
+    Operation *candidateSliceOp,
+    SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice, int curDepth = 0,
+    int maxDepth = 5) {
+  assert(isa<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
+  // Bound the recursion depth to avoid stack overflow.
+  if (curDepth > maxDepth)
+    return failure();
+
+  forwardSlice.push_back(
+      cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
+  Value resultOfLoop;
+  if (auto sliceOp =
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+    auto iterArg = dyn_cast<BlockArgument>(sliceOp.getDest());
+    if (!iterArg)
+      return failure();
+    auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+    if (!forallOp)
+      return failure();
+    resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
+    for (OpOperand &useOperand : sliceOp.getResult().getUses()) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        if (resultOfLoop)
+          return failure();
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+      }
+    }
+  }
+  if (!resultOfLoop)
+    return failure();
+
+  bool traverseUpperLoop;
+  do {
+    traverseUpperLoop = false;
+    for (OpOperand &useOperand : resultOfLoop.getUses()) {
+      if (auto sliceOp =
+              dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+        return getRealConsumersFromInsertSliceOp(sliceOp, forwardSlice,
+                                                 curDepth + 1, maxDepth);
+      } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        // The value is yielded again: continue with the outer loop's result.
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+        traverseUpperLoop = true;
+        break;
+      }
+    }
+  } while (traverseUpperLoop);
+  // Return all operands using result of top level loop.
+  return llvm::map_to_vector(resultOfLoop.getUses(),
+                             [](OpOperand &u) -> OpOperand * { return &u; });
+}
+
+/// Fuse the real consumer of a single slice, even within complex nested loops,
+/// by applying `tileAndFuseConsumerOfSliceImpl` once per collected slice op.
+/// Returns failure if `candidateSliceOp` is not a slice op or fusion fails.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
+  // The helper asserts on non-slice ops, so reject them gracefully up front.
+  if (!isa<OffsetSizeAndStrideOpInterface>(candidateSliceOp) ||
+      failed(getRealConsumersFromInsertSliceOp(candidateSliceOp, forwardSlice)))
+    return failure();
+
+  FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
+  // Apply `tileAndFuseConsumerOfSliceImpl` per slice, from outer to inner.
+  for (auto sliceOp : llvm::reverse(forwardSlice)) {
+    fuseConsumerResult = tileAndFuseConsumerOfSliceImpl(rewriter, sliceOp);
+    if (failed(fuseConsumerResult))
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "could not fuse consumer of sliceOp");
+  }
+  return fuseConsumerResult;
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcd..8022d5af37078 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -45,12 +45,12 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
 //      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
 // CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
-//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
 //      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
 //      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
 //      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
 // CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
 // CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
 //      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
 //      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
 //      CHECK:   }
@@ -170,13 +170,13 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
 //      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
 // CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
-//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
 //      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
 //      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
 //      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
 //      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
 // CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
 // CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
 //      CHECK:      %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
 //      CHECK:      %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
 //      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
@@ -315,3 +315,97 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module { // Test input: a matmul tiled by an scf.forall wrapping two nested scf.for loops, followed by an untiled linalg.add on the result.
+  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32> // Used both as the fill destination and as the `outs` of the trailing linalg.add.
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> { // Outer level: 2x2 grid, each iteration handles one 128x128 output tile.
+      %iv0 = affine.apply #map(%arg3) // Row offset: forall index * 128 (see #map).
+      %iv1 = affine.apply #map(%arg4) // Column offset: forall index * 128 (see #map).
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) { // Middle level: steps rows of the 128x128 tile by 64.
+        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) { // Inner level: steps columns by 64; its iter_arg is the middle loop's iter_arg.
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32> // The only tensor.insert_slice in the function; the transform sequence below matches on this op name.
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        scf.yield %3 : tensor<128x128xf32> // Inner loop result is forwarded unchanged to the middle loop.
+      }
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
+    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> // Consumer of the forall result; the expected output has a tiled linalg.add inside the innermost scf.for.
+    return %5 : tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} { // Transform script that drives the consumer-fusion test on the payload function.
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op // Handle to every tensor.insert_slice found under %arg1.
+    %a, %b = transform.test.fuse_consumer %slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // Both result handles are unused; only the rewritten payload IR is checked.
+    transform.yield
+  }
+}
+
+//      CHECK: #[[MAP0:.*]] =  affine_map<(d0) -> (d0 * 128)>
+//      CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[dest1:.*]] = linalg.fill
+// CHECK-SAME:          outs(%[[dest0]] :
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME:   {
+//      CHECK:      %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+//      CHECK:      %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+//      CHECK:      %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+//      CHECK:      %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+//      CHECK:      %[[ADD_OPERAND2_SLICE0:.*]] = tensor.extract_slice %[[ARG2]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME:          iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
+// CHECK-SAME:      {
+//      CHECK:          %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME:            iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
+// CHECK-SAME:          {
+//      CHECK:            %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+//      CHECK:            %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE1]] :
+//      CHECK:            %[[ADD_OPERAND2_SLICE1:.*]] = tensor.extract_slice %[[ADD_OPERAND2_SLICE0]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE1]] :
+// CHECK-SAME:              outs(%[[ADD_OUT_SLICE1]] :
+//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
+//      CHECK:          }
+//      CHECK:          scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
+//      CHECK:      }
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :
\ No newline at end of file



More information about the Mlir-commits mailing list