[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 3 00:43:08 PDT 2024


https://github.com/Yun-Fly created https://github.com/llvm/llvm-project/pull/94190

Hi, following the earlier discussion in [this thread](https://github.com/llvm/llvm-project/pull/88712#discussion_r1590568717), this patch extends the new consumer-fusion feature to more complex nested loop structures, e.g.:

```
// Input IR: a 256x256 linalg.matmul tiled into 128x128 tiles by a 2x2
// scf.forall and further into 64x64 tiles by two nested scf.for loops. The
// forall result feeds an untiled linalg.add consumer.
#map = affine_map<(d0) -> (d0 * 128)>
module {
  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c128 = arith.constant 128 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dest0 = tensor.empty() : tensor<256x256xf32>
    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    // Outer tiling level: each forall iteration owns a 128x128 tile of %arg5.
    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
      %iv0 = affine.apply #map(%arg3)
      %iv1 = affine.apply #map(%arg4)
      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
      // Inner tiling levels: two scf.for loops step through the tile by 64.
      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
          // Innermost insert level: the 64x64 matmul tile goes into %arg9.
          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
          scf.yield %insert_slice : tensor<128x128xf32>
        }
        scf.yield %3 : tensor<128x128xf32>
      }
      // Outer insert level: the 128x128 tile goes back into %arg5.
      scf.forall.in_parallel {
         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
      }
    }
    // Consumer of the forall result that the patch fuses into the loop nest.
    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    return %5 : tensor<256x256xf32>
  }
}
```

#### What's New in this PR:
1. Support for nested loop structures, including both `scf.for` and `scf.forall`.
2. Support for multi-level `insert_slice` or `parallel_insert_slice`.


NOTE: this PR does NOT address the `getTiledImplementation` refactor we discussed before; it focuses only on the functional enhancement. BTW, in the example above you can see a similar mismatch between the semantics of the tiled operand and the assumptions of the current `getTiledImplementation`, this time even on `dpsInits`. To unblock this patch, I temporarily follow the approach @MaheshRavishankar suggested and use a dummy `insert_slice` to bridge that gap.

The resulting IR looks like this:
```
// Resulting IR: the linalg.add consumer is fused next to the matmul in the
// innermost scf.for, and its tile is threaded out through every loop level.
#map = affine_map<(d0) -> (d0 * 128)>
#map1 = affine_map<(d0, d1) -> (d0 + d1 * 128)>
module {
  module {
    func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
      %c0 = arith.constant 0 : index
      %c64 = arith.constant 64 : index
      %c128 = arith.constant 128 : index
      %cst = arith.constant 0.000000e+00 : f32
      %0 = tensor.empty() : tensor<256x256xf32>
      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
      // The forall gains a second shared_out (%arg6, initialized with %0)
      // that carries the fused linalg.add result.
      %2:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1, %arg6 = %0) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
        %3 = affine.apply #map(%arg3)
        %4 = affine.apply #map(%arg4)
        %extracted_slice = tensor.extract_slice %arg5[%3, %4] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
        %extracted_slice_0 = tensor.extract_slice %arg0[%3, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
        %extracted_slice_1 = tensor.extract_slice %arg1[0, %4] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
        // Slice of the new shared_out, passed into the new scf.for iter_args.
        %extracted_slice_2 = tensor.extract_slice %arg6[%3, %4] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
        %5:2 = scf.for %arg7 = %c0 to %c128 step %c64 iter_args(%arg8 = %extracted_slice, %arg9 = %extracted_slice_2) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
          %6:2 = scf.for %arg10 = %c0 to %c128 step %c64 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
            %extracted_slice_3 = tensor.extract_slice %arg11[%arg7, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
            %extracted_slice_4 = tensor.extract_slice %extracted_slice_0[%arg7, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
            %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[0, %arg10] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
            %7 = linalg.matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
            // Fused consumer: offsets into %arg2 combine both loop levels
            // (#map1), and the init is a slice of the new iter_arg %arg12.
            %8 = affine.apply #map1(%arg7, %arg3)
            %9 = affine.apply #map1(%arg10, %arg4)
            %extracted_slice_6 = tensor.extract_slice %arg2[%8, %9] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
            %extracted_slice_7 = tensor.extract_slice %arg12[%arg7, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
            %10 = linalg.add ins(%7, %extracted_slice_6 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%extracted_slice_7 : tensor<64x64xf32>) -> tensor<64x64xf32>
            // Both the matmul tile and the fused add tile are inserted/yielded.
            %inserted_slice = tensor.insert_slice %7 into %arg11[%arg7, %arg10] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
            %inserted_slice_8 = tensor.insert_slice %10 into %arg12[%arg7, %arg10] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
            scf.yield %inserted_slice, %inserted_slice_8 : tensor<128x128xf32>, tensor<128x128xf32>
          }
          scf.yield %6#0, %6#1 : tensor<128x128xf32>, tensor<128x128xf32>
        }
        scf.forall.in_parallel {
          // The fused result tile is inserted into the new shared_out %arg6.
          tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
          tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
        }
      }
      // The original linalg.add is replaced by the forall's new result.
      return %2#1 : tensor<256x256xf32>
    }
  }
}
```

Looking forward to your suggestions and review. Thanks!

>From 2001ce0915ebc08e173ddab6e10251fab55c8160 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Sun, 2 Jun 2024 23:32:22 -0700
Subject: [PATCH] extend consumer fuse to nested scf loop

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 736 ++++++++++++------
 .../tile-and-fuse-consumer.mlir               |  96 +++
 2 files changed, 607 insertions(+), 225 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..9dd730e64a030 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1103,98 +1104,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 // tileAndFuseConsumerUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
-  Value result = candidateSliceOp.getResult();
-  Value::use_range uses = result.getUses();
-  if (!llvm::hasSingleElement(uses)) {
-    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
-    return failure();
-  }
-  OpOperand &operandUse = (*uses.begin());
-  Operation *userOp = operandUse.getOwner();
-  if (!isa<scf::YieldOp>(userOp)) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Expected scf.yield to be the only user, but got -> "
-               << (*userOp));
-    return failure();
-  }
-  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
-    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
-                               "be in the same block\n");
-    return failure();
-  }
-  return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
-                                                  Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
-    return failure();
-  return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1.  tensor.insert_slice has scf.yield as its only user.
-/// 2.  scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
-  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
-    return failure();
-  Value sliceResult = candidateSliceOp.getResult();
-  // Step 1. Fetch the corresponding output.
-  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
-  unsigned resultNumber = yieldOpOperand.getOperandNumber();
-  // Step 2. Check containing op is scf.for.
-  Operation *containingOp = candidateSliceOp->getParentOp();
-  auto forOp = dyn_cast<scf::ForOp>(containingOp);
-  if (!forOp)
-    return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
-  Value sliceDest = candidateSliceOp.getDest();
-  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
-  if (!iterArg)
-    return failure();
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
-    return failure();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp)
-    return failure();
-  Value resultingValue =
-      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
 /// This utility currently checks whether the loop either :-
 /// 1. Yields exactly one result.
 /// 2. Has consumer op as its first user and other users to be in the same
@@ -1220,31 +1129,116 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
   return success();
 }
 
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(insertSlice);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(parallelInsertSlice);
-  } else {
+// Collect every loop that encloses `sliceOp`, ordered from the outermost to
+// the innermost. When `untilLoop` is reached the walk stops there (inclusive).
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+    OffsetSizeAndStrideOpInterface sliceOp,
+    std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+  SmallVector<LoopLikeOpInterface> loops;
+  for (auto loop = sliceOp->getParentOfType<LoopLikeOpInterface>(); loop;
+       loop = loop->getParentOfType<LoopLikeOpInterface>()) {
+    loops.push_back(loop);
+    if (untilLoop && *untilLoop == loop)
+      break;
+  }
+  // Gathered inner-to-outer; return them outer-to-inner.
+  return {loops.rbegin(), loops.rend()};
+}
+
+// Walk from `targetSliceOp` up to the result of the outermost loop that it
+// (transitively) feeds through scf.yield / [parallel_]insert_slice, e.g.
+// ```
+// %1 = scf.for
+//  %2 = scf.for
+//   %3 = scf.for
+//      ...
+//      %4 = insert
+//      yield %4
+//   %5 = insert %3
+//   yield %5
+//  yield %2
+// ```
+// For targetSliceOp `%4 = insert` this returns %1 plus the slice ops visited
+// on the way, innermost first and including targetSliceOp: [%4, %5].
+// Each further slice op costs one recursion level, bounded by `maxDepth`.
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(
+    OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0,
+    int maxDepth = 5) {
+  // Bound the recursion depth to avoid stack overflow.
+  if (curDepth > maxDepth)
+    return failure();
+
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+  candidateSliceOpList.push_back(targetSliceOp);
+  Value resultOfLoop;
+  if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
+          targetSliceOp.getOperation())) {
+    // Bail out (instead of asserting) unless dest is an scf.forall bbArg.
+    auto iterArg = dyn_cast<BlockArgument>(sliceOp.getDest());
+    auto forallOp =
+        iterArg ? dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp())
+                : scf::ForallOp();
+    if (!forallOp)
+      return failure();
+    resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
+                 targetSliceOp.getOperation())) {
+    // At most one scf.yield of an enclosing scf.for may consume the result.
+    for (OpOperand &useOperand : sliceOp.getResult().getUses()) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+        if (resultOfLoop || !forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+      }
+    }
+  }
+
+  if (!resultOfLoop)
     return failure();
+
+  while (true) {
+    bool walkThroughOuterLoop = false;
+    for (OpOperand &useOperand : resultOfLoop.getUses()) {
+      Operation *user = useOperand.getOwner();
+      // Only continue via inserts that take the loop result as their source.
+      if (isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(user) &&
+          useOperand.getOperandNumber() == 0) {
+        auto nested = getResultOfTopLevelLoopYieldInsertSliceOp(
+            cast<OffsetSizeAndStrideOpInterface>(user), curDepth + 1, maxDepth);
+        if (failed(nested))
+          return failure();
+        llvm::append_range(candidateSliceOpList, nested->second);
+        return std::make_pair(nested->first, candidateSliceOpList);
+      }
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
+        auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+        walkThroughOuterLoop = true;
+        break;
+      }
+    }
+    if (!walkThroughOuterLoop)
+      break;
   }
+  return std::make_pair(resultOfLoop, candidateSliceOpList);
 }
 
 /// After fusing consumer into scf.for we want to modify the scf.yield operation
 /// to reflect the same by returning the values yielded by the tiled consumer.
 static void
 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      TilingResult &tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ResultRange tilingResult,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
                       ArrayRef<BlockArgument> bbArgs) {
   scf::YieldOp oldTerminatorOp =
       cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
   unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  unsigned totalTiledResults = tilingResult.size();
   SmallVector<Value> newYieldOperands;
   newYieldOperands.reserve(totalOldResults + totalTiledResults);
   for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1253,8 +1247,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
   rewriter.setInsertionPointAfter(oldTerminatorOp);
   Location loc = newForOp.getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
-                       resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1267,18 +1260,17 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
 
 /// After fusing consumer into scf.forall we want to yield each of the resulting
 /// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
-                           ArrayRef<BlockArgument> bbArgs) {
+static void fixTerminatorSCFInParallel(
+    RewriterBase &rewriter, scf::ForallOp newForallOp, ResultRange tilingResult,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+    ArrayRef<BlockArgument> bbArgs) {
   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
   rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
   Location firstYieldOpLoc =
       (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1286,6 +1278,180 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
   }
 }
 
+// When the outermost loop of the nest is an scf.forall, each newly appended
+// `shared_outs` block argument must be sliced (with the given offsets/sizes
+// and unit strides) before it is threaded into the inner loops, e.g.
+//
+// scf.forall shared_outs(%o1=..., %o2=...) {
+//     %local_o1 = extract_slice %o1
+//     // fix new appended `shared_out` %o2
+//     %local_o2 = extract_slice %o2
+//     scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+//        ...
+//     }
+//     ...
+// }
+// The extract_slice ops, created before `innerLoop`, go to `newExtractOps`.
+static void
+fixSharedOutSCFForall(RewriterBase &rewriter, scf::ForallOp outerLoop,
+                      LoopLikeOpInterface innerLoop,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+                      unsigned newInitSize,
+                      SmallVector<tensor::ExtractSliceOp> &newExtractOps) {
+  rewriter.setInsertionPoint(innerLoop);
+  Location loc = outerLoop.getLoc();
+  ArrayRef<BlockArgument> appendedArgs =
+      outerLoop.getBody()->getArguments().take_back(newInitSize);
+  newExtractOps.clear();
+  newExtractOps.reserve(resultOffsets.size());
+  for (auto [arg, tileOffsets, tileSizes] :
+       llvm::zip_equal(appendedArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> unitStrides(tileOffsets.size(),
+                                          rewriter.getIndexAttr(1));
+    newExtractOps.push_back(rewriter.create<tensor::ExtractSliceOp>(
+        loc, arg, tileOffsets, tileSizes, unitStrides));
+  }
+}
+
+// With an `scf.forall` as outermost loop, re-point each DPS init of the tiled
+// consumer at a unit-stride slice of its `bbArgs` entry (via the rewriter).
+static void fixDpsInitsOfTiledConsumer(
+    RewriterBase &rewriter, Operation *tiledConsumer,
+    ArrayRef<BlockArgument> bbArgs,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes) {
+  rewriter.setInsertionPoint(tiledConsumer);
+  Location loc = tiledConsumer->getLoc();
+  auto dstOp = cast<DestinationStyleOpInterface>(tiledConsumer);
+  for (auto &&[bbArg, offset, sizes, dpsInit] : llvm::zip_equal(
+           bbArgs, resultOffsets, resultSizes, dstOp.getDpsInitsMutable())) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    Value newInit = rewriter.create<tensor::ExtractSliceOp>(
+        loc, bbArg, offset, sizes, strides);
+    rewriter.modifyOpInPlace(tiledConsumer,
+                             [&init = dpsInit, newInit] { init.set(newInit); });
+  }
+}
+
+// Compute the tile of every result of `tilableOp` implied by the unit-stride
+// operand tile `ossSliceOp` of operand `operandNumber`. The out-params are
+// written only on success.
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+    RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+    OffsetSizeAndStrideOpInterface ossSliceOp,
+    SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+  if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+  }
+  // 1. Map the operand tile back onto the iteration domain.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+          ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        tilableOp, "can't get iter domain position from input position");
+  }
+  // 2. Derive the tile of each result from the iteration-domain tile.
+  unsigned numResults = tilableOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(numResults);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(numResults);
+  for (unsigned idx = 0; idx < numResults; ++idx) {
+    if (failed(tilableOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          tilableOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+  allResultOffsets = std::move(resultOffsets);
+  allResultSizes = std::move(resultSizes);
+  return success();
+}
+
+// Multi-level tensor.*SliceOps may use different coordinate systems. Compute
+// the OFFSETs of `candidateSliceOp` relative to the ROOT (outermost) slice op
+// of `candidateSliceOpList` (ordered innermost first). E.g.
+//             %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+//         %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+//
+// where the coordinate systems relate as follows:
+//  %3 ----------------------------------
+//  |         |         |
+//  | OFFSET2 | OFFSET1 |
+//  | ------ %0         |
+//  |                   |
+//  |                   |
+//  |------------------ %1 ------ |
+//  |                   |  SIZE1  |
+//  |                   |         |
+//  |                   |         |
+//  |                   | ------- |
+//  |
+//
+// The real OFFSET of %1 in the coordinates of %3 is `OFFSET1` + `OFFSET2`.
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+    RewriterBase &rewriter, Location loc,
+    OffsetSizeAndStrideOpInterface candidateSliceOp,
+    MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+  auto hasNonUnitStride = [](OffsetSizeAndStrideOpInterface op) {
+    return llvm::any_of(op.getMixedStrides(), [](OpFoldResult stride) {
+      return !isConstantIntValue(stride, 1);
+    });
+  };
+  if (hasNonUnitStride(candidateSliceOp))
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "candidateSliceOp has stride");
+  // The walk below stops at candidateSliceOp; make sure it is present so the
+  // reverse iteration cannot run past the end of candidateSliceOpList.
+  if (!llvm::is_contained(candidateSliceOpList, candidateSliceOp))
+    return failure();
+  using AVE = affine::AffineValueExpr;
+  affine::AffineBuilder ab(rewriter, loc);
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+  // Accumulate the offsets of every enclosing slice op, outermost first.
+  for (auto it = candidateSliceOpList.rbegin(); *it != candidateSliceOp;
+       ++it) {
+    if (hasNonUnitStride(*it))
+      return failure();
+    for (auto &&[ofr1, ofr2] :
+         llvm::zip_equal(realOffsets, it->getMixedOffsets())) {
+      auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+      ofr1 = ab.add(aveOffset1, aveOffset2);
+    }
+  }
+  return realOffsets;
+}
+
+// Return the first use of `val` whose owner implements TilingInterface and
+// DestinationStyleOpInterface, shares `loopOp`'s block and passes
+// checkAssumptionForLoop(loopOp, owner). Fails when no use qualifies.
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+  auto isFusableUse = [&](OpOperand &use) {
+    Operation *user = use.getOwner();
+    // Checked in order: tilable + destination style, sibling of `loopOp`,
+    // and `loopOp` satisfies checkAssumptionForLoop with respect to this
+    // user (only reached when the cheaper checks pass).
+    return isa<TilingInterface>(user) &&
+           isa<DestinationStyleOpInterface>(user) &&
+           user->getBlock() == loopOp->getBlock() &&
+           succeeded(checkAssumptionForLoop(loopOp, user));
+  };
+  for (OpOperand &use : val.getUses()) {
+    if (isFusableUse(use))
+      return &use;
+  }
+  return failure();
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1297,10 +1463,29 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
 
-  // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice/parallel_insert_slice.
+  // 1.a Get the real consumer of candidate
+  // tensor.insert_slice/parallel_insert_slice by walking through
+  // scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the
+  // way.
+  auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
+  FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+      resultAndSliceOpsPair =
+          getResultOfTopLevelLoopYieldInsertSliceOp(ossSliceOp);
+  if (failed(resultAndSliceOpsPair)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+
+  // 1.b Get all outer loops of candidateSliceOp.
+  SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsOfSliceOp(
+      ossSliceOp, dyn_cast<OpResult>((*resultAndSliceOpsPair).first)
+                      .getDefiningOp<LoopLikeOpInterface>());
+  LoopLikeOpInterface outerMostLoop = outerLoops.front();
+
+  // 2 Get first tilable consumer op
   FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getUntiledConsumerFromSlice(candidateSliceOp);
+      getTilableConsumerOperandFirstUseVal((*resultAndSliceOpsPair).first,
+                                           outerMostLoop);
   if (failed(maybeConsumerOpOperand)) {
     return rewriter.notifyMatchFailure(candidateSliceOp,
                                        "could not fetch consumer to fuse");
@@ -1316,111 +1501,187 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  Operation *oldLoopOp = nullptr;
-  SmallVector<Value> newOuts;
-  Block *oldLoopBody = nullptr;
-  unsigned initSize = 0;
-  unsigned rank = 1;
-  if (isInsertSliceOp) {
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-    oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
-    initSize = forOp.getInits().size();
-  } else {
-    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
-    oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
-    initSize = forallOp.getOutputs().size();
-    rank = forallOp.getRank();
-  }
-
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
-    return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
-  }
-
-  OpBuilder::InsertionGuard g(rewriter);
-
-  // 2. Check consumer is not using scf loop's output as init.
+  // 3. Check consumer is not using outerMostLoop's output as init.
   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
+  ValueRange newInitAppend = dpsInits;
 
-  Location loc = oldLoopOp->getLoc();
+  // 4. reconstruct nested loop from outer to inner.
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
+      (*resultAndSliceOpsPair).second;
+  SmallVector<LoopLikeOpInterface> newOuterLoops;
+  SmallVector<SmallVector<OpFoldResult>> allResultOffsets, allResultSizes;
+  // extract slice from newInits of outer-most scf.forall
+  SmallVector<tensor::ExtractSliceOp> newExtractOps;
 
-  // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
-  Operation *newLoopOp = nullptr;
+  Block *oldLoopBody = nullptr;
   Block *newLoopBody = nullptr;
+  SmallVector<Value> newOuts;
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // set insertPoint right before consumerOp
+  rewriter.setInsertionPoint(consumerOp);
+
+  for (auto [index, loop] :
+       llvm::enumerate(MutableArrayRef(outerLoops).drop_back())) {
+    if (index > 0)
+      rewriter.setInsertionPoint(loop);
+
+    LoopLikeOpInterface newLoopOp;
+    // Create a new loop with the new init values for this loop.
+    if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+      newOuts = llvm::to_vector(forOp.getInits());
+      newOuts.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForOp>(
+          forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+          forOp.getStep(), newOuts);
+      newLoopOp = newLoop;
+      oldLoopBody = forOp.getBody();
+      newLoopBody = newLoop.getBody();
+      newInitAppend =
+          newLoopBody->getArguments().take_back(newInitAppend.size());
+    } else if (auto forallOp = dyn_cast<scf::ForallOp>(loop.getOperation())) {
+      newOuts = llvm::to_vector(forallOp.getOutputs());
+      newOuts.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForallOp>(
+          forallOp.getLoc(), forallOp.getMixedLowerBound(),
+          forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+          forallOp.getMapping());
+      rewriter.eraseOp(newLoop.getTerminator());
+      newLoopOp = newLoop;
+      oldLoopBody = forallOp.getBody();
+      newLoopBody = newLoop.getBody();
+
+      // create extractSliceOp for newInits
+      assert(index == 0 && "Currently Only outerMostLoop assumed ForallOp");
+      auto outerMostCandidate = candidateSliceOpList.back();
+      assert(isa<tensor::ParallelInsertSliceOp>(outerMostCandidate));
+      // set InsertPoint before next inner loop
+      auto nextLoop = outerLoops[index + 1];
+      rewriter.setInsertionPoint(nextLoop);
+      if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+              rewriter, cast<TilingInterface>(consumerOp), operandNumber,
+              outerMostCandidate, allResultOffsets, allResultSizes))) {
+        return failure();
+      }
+      fixSharedOutSCFForall(rewriter, newLoop, nextLoop, allResultOffsets,
+                            allResultSizes, newInitAppend.size(),
+                            newExtractOps);
+      newInitAppend = llvm::map_to_vector(
+          newExtractOps,
+          [](tensor::ExtractSliceOp op) -> Value { return op.getResult(); });
+    }
+    rewriter.mergeBlocks(
+        oldLoopBody, newLoopBody,
+        newLoopBody->getArguments().take_front(oldLoopBody->getNumArguments()));
+    rewriter.replaceOp(
+        loop, newLoopOp->getResults().take_front(loop->getNumResults()));
+    newOuterLoops.push_back(newLoopOp);
+  }
+
+  // 5.a reconstruct inner-most loop.
+  LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
+  Location loc = oldInnerMostLoop->getLoc();
+  if (outerLoops.size() > 1)
+    rewriter.setInsertionPoint(oldInnerMostLoop);
+
   if (isInsertSliceOp) {
-    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto forOp = cast<scf::ForOp>(oldInnerMostLoop.getOperation());
+    newOuts = llvm::to_vector(forOp.getInits());
+    newOuts.append(newInitAppend.begin(), newInitAppend.end());
+    oldLoopBody = forOp.getBody();
     auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                 forOp.getUpperBound(),
                                                 forOp.getStep(), newOuts);
-    newLoopOp = newForOp;
+    newInnerMostLoop = newForOp;
     newLoopBody = newForOp.getBody();
   } else {
-    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto forallOp = cast<scf::ForallOp>(oldInnerMostLoop.getOperation());
+    newOuts = llvm::to_vector(forallOp.getOutputs());
+    newOuts.append(newInitAppend.begin(), newInitAppend.end());
+    oldLoopBody = forallOp.getBody();
     auto newForallOp = rewriter.create<scf::ForallOp>(
         loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
         forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-    newLoopOp = newForallOp;
+    newInnerMostLoop = newForallOp;
     rewriter.eraseOp(newForallOp.getTerminator());
     newLoopBody = newForallOp.getBody();
   }
 
-  // 4. Move the loop body to the new op.
+  // 5.b Move the loop body to the new op.
   unsigned oldNumArguments = oldLoopBody->getNumArguments();
   rewriter.mergeBlocks(oldLoopBody, newLoopBody,
                        newLoopBody->getArguments().take_front(oldNumArguments));
+  // 5.c Replace the results of oldInnerMostLoop with the leading results of
+  // newInnerMostLoop.
+  rewriter.replaceOp(oldInnerMostLoop,
+                     newInnerMostLoop->getResults().take_front(
+                         oldInnerMostLoop->getNumResults()));
 
-  // 5. Set insertion point before terminator op of the loop and create a new
+  // 6. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
   tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
     rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
   }
-
-  // 6.a. Clone consumer op.
-  auto newForOpBlockArgsForConsumerDest =
-      newLoopBody->getArguments().drop_front(oldNumArguments);
+  FailureOr<SmallVector<OpFoldResult>> realOffsets =
+      computeRealOffsetsCoordinatedRootSliceOp(rewriter, loc, ossSliceOp,
+                                               candidateSliceOpList);
+  if (failed(realOffsets))
+    return failure();
+  // create dummy insertSliceOp to align with the requirement of current
+  // Tiling interface and fix potential semantic mismatch with later
+  // extractSliceOp generated by `getTiledImplementation`.
+  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+      loc, candidateSliceOp->getOperand(0),
+      candidateSliceOpList.back()->getOperand(1), *realOffsets,
+      ossSliceOp.getMixedSizes(), ossSliceOp.getMixedStrides());
+
+  // 7.a. Clone consumer op.
+  SmallVector<Value> newDpsInitsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(oldNumArguments),
+      [](BlockArgument bArg) -> Value { return bArg; });
+  ;
+  // align dps inits if necessary
+  if (!newExtractOps.empty()) {
+    for (auto &&[extractOp, newDpsInit] :
+         llvm::zip_equal(newExtractOps, newDpsInitsForConsumerDest)) {
+      auto alignDpsInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+          loc, newDpsInit, extractOp.getSource(), extractOp.getMixedOffsets(),
+          extractOp.getMixedSizes(), extractOp.getMixedStrides());
+      newDpsInit = alignDpsInsertSliceOp.getResult();
+    }
+  }
   auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+      rewriter, consumerOp, newDpsInitsForConsumerDest));
 
-  // 6.b. Replace all uses of the loop result with the result of the cloned
+  // 7.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
     operandToReplace.set(clonedInsertSliceOp.getResult());
   });
 
-  // 7 - Perform tiling of the cloned consumer and replace the operand at
+  // 8. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
-  auto ossSliceOp =
-      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
       tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedInsertSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
   if (failed(tileAndFuseResult)) {
     return failure();
   }
@@ -1428,75 +1689,100 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
       tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
       clonedInsertSliceOp.getSource());
 
-  // 8 - Extract offset/sizes/strides required to create the
-  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
-  // 9. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
-
-  // 10. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
-
-  // 11. Try to fetch the offset and size for all results of the cloned
+  // 9. Try to fetch the offset and size for all results of the cloned
   // consumer. This would then be used to form the corresponding
   // tensor.insert_slice/parallel_insert_slice later.
-  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-    if (failed(clonedConsumerOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
-      return rewriter.notifyMatchFailure(
-          clonedConsumerOp,
-          "can't get result domain position from iter domain position");
-    }
+  if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+          rewriter, clonedConsumerOp, operandNumber, ossSliceOp,
+          allResultOffsets, allResultSizes))) {
+    return failure();
+  }
+
+  if (!newExtractOps.empty()) {
+    fixDpsInitsOfTiledConsumer(
+        rewriter, tileAndFuseResult->tiledOps[0],
+        newLoopBody->getArguments().drop_front(oldNumArguments),
+        allResultOffsets, allResultSizes);
   }
 
-  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
-  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
   if (isInsertSliceOp) {
-    auto newForOp = cast<scf::ForOp>(newLoopOp);
+    auto newForOp = cast<scf::ForOp>(newInnerMostLoop);
     fixTerminatorSCFYield(
-        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
-        newForOp.getBody()->getArguments().drop_front(1 + initSize));
+        rewriter, newForOp, tileAndFuseResult->tiledOps[0]->getResults(),
+        allResultOffsets, allResultSizes,
+        newForOp.getBody()->getArguments().take_back(newInitAppend.size()));
   } else {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
     fixTerminatorSCFInParallel(
         rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
-        arrayRefOffsets, arrayRefSizes,
-        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
+        allResultOffsets, allResultSizes,
+        newForallOp.getBody()->getArguments().take_back(newInitAppend.size()));
   }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
-  for (auto &&[oldResult, newResult] :
-       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
-    rewriter.replaceAllUsesWith(oldResult, newResult);
+  newOuterLoops.push_back(cast<LoopLikeOpInterface>(newInnerMostLoop));
+
+  // 10. reconstruct terminator of outer loop by inner loop.
+  auto outerCandidateIter = candidateSliceOpList.rbegin();
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip_equal(MutableArrayRef(newOuterLoops).drop_back(),
+                       MutableArrayRef(newOuterLoops).drop_front())) {
+    // Create the insertSliceOp according to the outer candidateSliceOp.
+    if (outerCandidateIter != candidateSliceOpList.rend() &&
+        outerCandidateIter->getOperation()
+                ->getParentOfType<LoopLikeOpInterface>() == outerLoop) {
+      if (auto forallOp = dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
+        rewriter.setInsertionPoint(forallOp.getTerminator());
+      } else {
+        rewriter.setInsertionPointAfter(*outerCandidateIter);
+      }
+
+      if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+              rewriter, clonedConsumerOp, operandNumber, *outerCandidateIter,
+              allResultOffsets, allResultSizes))) {
+        return failure();
+      }
+
+      if (auto forOp = dyn_cast<scf::ForOp>(outerLoop.getOperation())) {
+        fixTerminatorSCFYield(
+            rewriter, forOp,
+            innerLoop->getResults().take_back(newInitAppend.size()),
+            allResultOffsets, allResultSizes,
+            forOp.getBody()->getArguments().take_back(newInitAppend.size()));
+      } else if (auto forallOp =
+                     dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
+        fixTerminatorSCFInParallel(
+            rewriter, forallOp,
+            innerLoop->getResults().take_back(newInitAppend.size()),
+            allResultOffsets, allResultSizes,
+            forallOp.getBody()->getArguments().take_back(newInitAppend.size()));
+      }
+      outerCandidateIter++;
+    } else {
+      // yield additional new appended results of innerLoop
+      assert(isa<scf::ForOp>(outerLoop));
+      auto forOp = cast<scf::ForOp>(outerLoop);
+      auto outerLoopYield =
+          cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+      SmallVector<Value> newYields =
+          llvm::to_vector(outerLoopYield.getOperands());
+      ValueRange additionalYields =
+          innerLoop->getResults().take_back(newInitAppend.size());
+      newYields.append(additionalYields.begin(), additionalYields.end());
+      rewriter.setInsertionPoint(outerLoopYield);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+    }
   }
 
+  // 11. Replace the result of consumer op with new outerMost loop's
+  // results.
   for (auto &&[oldResult, newResult] :
        llvm::zip(consumerOp->getResults(),
-                 newLoopOp->getResults().drop_front(initSize))) {
+                 newOuterLoops.front()->getResults().take_back(
+                     newInitAppend.size()))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the old scf loop and the cloned consumer op.
-  rewriter.eraseOp(oldLoopOp);
+  // 12. Need to erase the cloned consumer op.
   rewriter.eraseOp(clonedConsumerOp);
 
   return scf::SCFFuseConsumerOfSliceResult{
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcd..d60d5f4fd7b3c 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,3 +315,103 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+// Consumer fusion through a nested loop structure: the tiled producer
+// (linalg.matmul) lives in scf.forall -> scf.for -> scf.for, and its consumer
+// (linalg.add) is expected to be fused into the innermost loop body.
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg3)
+      %iv1 = affine.apply #map(%arg4)
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        scf.yield %3 : tensor<128x128xf32>
+      }
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
+    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %5 : tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//      CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
+//      CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 128)>
+//      CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[dest1:.*]] = linalg.fill
+// CHECK-SAME:          outs(%[[dest0]] :
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME:   {
+//      CHECK:      %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+//      CHECK:      %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+//      CHECK:      %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+//      CHECK:      %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+//      CHECK:      %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME:          iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
+// CHECK-SAME:      {
+//      CHECK:          %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME:            iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
+// CHECK-SAME:          {
+//      CHECK:            %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+//      CHECK:            %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE1]] :
+//      CHECK:            %[[REAL_SECOND_IV1:.*]] = affine.apply #[[MAP1]](%[[IV3]], %[[IV1]])
+//      CHECK:            %[[REAL_SECOND_IV2:.*]] = affine.apply #[[MAP1]](%[[IV4]], %[[IV2]])
+//      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[REAL_SECOND_IV1]], %[[REAL_SECOND_IV2]]] [64, 64] [1, 1]
+//      CHECK:            %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
+// CHECK-SAME:              outs(%[[ADD_OUT_SLICE1]] :
+//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
+//      CHECK:          }
+//      CHECK:          scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
+//      CHECK:      }
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :



More information about the Mlir-commits mailing list