[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 5 06:29:54 PDT 2024


https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/94190

>From 9c04ad4f082145c79ab521a6c57854f816b817f3 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Sun, 2 Jun 2024 23:32:22 -0700
Subject: [PATCH] extend consumer fuse to nested scf loop

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 731 ++++++++++++------
 .../tile-and-fuse-consumer.mlir               |  96 +++
 2 files changed, 606 insertions(+), 221 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..824c366c014f2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1103,98 +1104,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 // tileAndFuseConsumerUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
-  Value result = candidateSliceOp.getResult();
-  Value::use_range uses = result.getUses();
-  if (!llvm::hasSingleElement(uses)) {
-    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
-    return failure();
-  }
-  OpOperand &operandUse = (*uses.begin());
-  Operation *userOp = operandUse.getOwner();
-  if (!isa<scf::YieldOp>(userOp)) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Expected scf.yield to be the only user, but got -> "
-               << (*userOp));
-    return failure();
-  }
-  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
-    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
-                               "be in the same block\n");
-    return failure();
-  }
-  return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
-                                                  Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
-    return failure();
-  return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1.  tensor.insert_slice has scf.yield as its only user.
-/// 2.  scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
-  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
-    return failure();
-  Value sliceResult = candidateSliceOp.getResult();
-  // Step 1. Fetch the corresponding output.
-  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
-  unsigned resultNumber = yieldOpOperand.getOperandNumber();
-  // Step 2. Check containing op is scf.for.
-  Operation *containingOp = candidateSliceOp->getParentOp();
-  auto forOp = dyn_cast<scf::ForOp>(containingOp);
-  if (!forOp)
-    return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
-  Value sliceDest = candidateSliceOp.getDest();
-  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
-  if (!iterArg)
-    return failure();
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
-    return failure();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp)
-    return failure();
-  Value resultingValue =
-      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
 /// This utility currently checks whether the loop either :-
 /// 1. Yields exactly one result.
 /// 2. Has consumer op as its first user and other users to be in the same
@@ -1220,31 +1129,117 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
   return success();
 }
 
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(insertSlice);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(parallelInsertSlice);
-  } else {
+/// Collect all loops enclosing the given sliceOp, ordered from outermost to
+/// innermost. If `untilLoop` is given, stop the walk once it is collected.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+    Operation *sliceOp,
+    std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+  assert(isa<OffsetSizeAndStrideOpInterface>(sliceOp));
+  SmallVector<LoopLikeOpInterface> outerLoops;
+  auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+  while (forOp) {
+    outerLoops.push_back(forOp);
+    if (untilLoop.has_value() && *untilLoop == forOp)
+      break;
+    forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+  }
+  return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+/// Get the result of the top-level loop yielding the target InsertSliceOp. E.g.
+/// ```
+/// %1 = scf.for
+///  %2 = scf.for
+///   %3 = scf.for
+///      ...
+///      %4 = insert
+///      yield %4
+///   %5 = insert %3
+///   yield %5
+///  yield %2
+/// ```
+/// @param targetSliceOp: %4 = insert
+/// @return first: result value of the top-level loop: %1
+///         second: list of insertSliceOps collected during the walk, including
+///                   targetSliceOp: %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(Operation *targetSliceOp,
+                                          int curDepth = 0, int maxDepth = 5) {
+  assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+  // Bound the recursion depth to avoid stack overflow.
+  if (curDepth > maxDepth)
+    return failure();
+
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+  candidateSliceOpList.push_back(
+      cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+  Value resultOfLoop;
+  if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
+    Value destValue = sliceOp.getDest();
+    auto iterArg = cast<BlockArgument>(destValue);
+    auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+    if (!forallOp)
+      return failure();
+    resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
+    Value resultValue = sliceOp.getResult();
+    for (auto &useOperand : resultValue.getUses()) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        if (llvm::detail::isPresent(resultOfLoop))
+          return failure();
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+      }
+    }
+  }
+
+  if (!llvm::detail::isPresent(resultOfLoop))
     return failure();
+
+  while (true) {
+    bool walkThroughOuterLoop = false;
+    for (OpOperand &useOperand : resultOfLoop.getUses()) {
+      if (auto sliceOp =
+              dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+        FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+            resultAndSliceOpsPair = getResultOfTopLevelLoopYieldInsertSliceOp(
+                sliceOp, curDepth + 1);
+        if (failed(resultAndSliceOpsPair))
+          return failure();
+        candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+                                    (*resultAndSliceOpsPair).second.end());
+        return std::make_pair((*resultAndSliceOpsPair).first,
+                              candidateSliceOpList);
+      } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        // walk through outer loop
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+        walkThroughOuterLoop = true;
+        break;
+      }
+    }
+    if (!walkThroughOuterLoop)
+      break;
   }
+  return std::make_pair(resultOfLoop, candidateSliceOpList);
 }
 
 /// After fusing consumer into scf.for we want to modify the scf.yield operation
 /// to reflect the same by returning the values yielded by the tiled consumer.
 static void
 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      TilingResult &tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ResultRange tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
                       ArrayRef<BlockArgument> bbArgs) {
   scf::YieldOp oldTerminatorOp =
       cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
   unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  unsigned totalTiledResults = tilingResult.size();
   SmallVector<Value> newYieldOperands;
   newYieldOperands.reserve(totalOldResults + totalTiledResults);
   for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1253,8 +1248,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
   rewriter.setInsertionPointAfter(oldTerminatorOp);
   Location loc = newForOp.getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
-                       resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1269,16 +1263,16 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
 /// values by the tiled consumer within scf.forall.in_parallel region.
 static void
 fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                           ResultRange tilingResult,
+                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> resultSizes,
                            ArrayRef<BlockArgument> bbArgs) {
   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
   rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
   Location firstYieldOpLoc =
       (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1286,6 +1280,176 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
   }
 }
 
+/// If the top-level loop of a nested loop structure is scf.forall, additional
+/// tensor.extract_slice ops must be created for its newly appended
+/// `shared_outs` so that inner loops receive the correct local slice. E.g.
+///
+/// scf.forall shared_outs(%o1=..., %o2=...) {
+///     %local_o1 = extract_slice %o1
+///     // extract_slice for the newly appended `shared_out` %o2
+///     %local_o2 = extract_slice %o2
+///     scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+///        ...
+///     }
+///     ...
+/// }
+static SmallVector<tensor::ExtractSliceOp> fixLoopInitFromSharedOutSCFForall(
+    RewriterBase &rewriter, Operation *loop, ArrayRef<Value> newSharedOuts,
+    ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+    ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+  rewriter.setInsertionPoint(loop);
+  Location loc = loop->getLoc();
+  // Create a unit-stride extract_slice before `loop` for each new shared_out.
+  SmallVector<tensor::ExtractSliceOp> newExtractOps;
+  newExtractOps.reserve(resultOffsets.size());
+  for (auto [bbArg, offset, sizes] :
+       llvm::zip_equal(newSharedOuts, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, bbArg, offset, sizes, strides);
+    newExtractOps.push_back(newExtractOp);
+  }
+  return newExtractOps;
+}
+
+/// For the case where the outermost loop of the nest is `scf.forall`: replace
+/// each DPS init of the tiled consumer with an extract_slice of its `bbArg`.
+static void
+fixDpsInitsOfTiledConsumer(RewriterBase &rewriter, Operation *tiledConsumer,
+                           ArrayRef<BlockArgument> bbArgs,
+                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+  rewriter.setInsertionPoint(tiledConsumer);
+  Location loc = tiledConsumer->getLoc();
+  for (auto &&[bbArg, offset, sizes, dpsInit] :
+       llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+                       cast<DestinationStyleOpInterface>(tiledConsumer)
+                           .getDpsInitsMutable())) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, bbArg, offset, sizes, strides);
+    dpsInit.set(newExtractOp.getResult());
+  }
+}
+
+/// Compute offsets and sizes of all result tiles from the slice of one operand.
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+    RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+    OffsetSizeAndStrideOpInterface ossSliceOp,
+    SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+  // 1. Check that all strides are 1.
+  if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+  }
+  // 2. Compute the iteration domain tile from the operand tile.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+          ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        tilableOp, "can't get iter domain position from input position");
+  }
+  unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  // 3. Compute the tile position of each result.
+  for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
+    if (failed(tilableOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          tilableOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+  allResultOffsets = resultOffsets;
+  allResultSizes = resultSizes;
+  return success();
+}
+
+/// Multi-level tensor.*SliceOps may each be expressed in a different
+/// coordinate system; this utility computes the real OFFSET relative to the
+/// ROOT SliceOp. E.g.
+///             %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+///         %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+///
+/// where the coordinate systems can be illustrated as follows:
+///
+///  %3 ----------------------------------
+///  |         |         |
+///  | OFFSET2 | OFFSET1 |
+///  | ------ %0         |
+///  |                   |
+///  |                   |
+///  |------------------ %1 ------ |
+///  |                   |  SIZE1  |
+///  |                   |         |
+///  |                   |         |
+///  |                   | ------- |
+///  |
+///
+/// The real OFFSET of %1 relative to %3 is therefore `OFFSET1` + `OFFSET2`.
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+    RewriterBase &rewriter, Location loc,
+    OffsetSizeAndStrideOpInterface candidateSliceOp,
+    MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "candidateSliceOp has stride");
+  }
+  SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+  // The real offsets are the accumulated offsets of all outer candidates.
+  for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
+       iter++) {
+    // Bail out unless every outer candidate slice has unit strides.
+    if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
+      return failure();
+    }
+    for (auto &&[ofr1, ofr2] :
+         llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
+      using AVE = affine::AffineValueExpr;
+      affine::AffineBuilder ab(rewriter, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(rewriter.getContext(), dim0, dim1);
+      bindSymbols(rewriter.getContext(), sym);
+      auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+      ofr1 = ab.add(aveOffset1, aveOffset2);
+    }
+  }
+  return realOffsets;
+}
+
+/// Return the first use of `val` whose owner is a tilable DPS op in the same
+/// block as `loopOp` and passes `checkAssumptionForLoop`; failure otherwise.
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+  for (auto &useOfval : val.getUses()) {
+    Operation *consumerOp = useOfval.getOwner();
+    // 1. Skip users that are not tilable destination-style ops.
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp))
+      continue;
+    // 2. Skip users that are not in the same block as loopOp.
+    if (loopOp->getBlock() != consumerOp->getBlock())
+      continue;
+    // 3. Skip users that violate the assumptions checked for loopOp.
+    if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
+      continue;
+    }
+    return &useOfval;
+  }
+  return failure();
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1297,10 +1461,28 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
 
-  // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice/parallel_insert_slice.
+  // 1.a. Get the top-level loop result fed by the candidate
+  // tensor.insert_slice/parallel_insert_slice by walking up through
+  // scf.for/scf.forall, collecting all [Parallel]InsertSliceOp(s) along the
+  // way.
+  FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+      resultAndSliceOpsPair =
+          getResultOfTopLevelLoopYieldInsertSliceOp(candidateSliceOp);
+  if (failed(resultAndSliceOpsPair)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+
+  // 1.b. Get all outer loops of candidateSliceOp.
+  SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsOfSliceOp(
+      candidateSliceOp, dyn_cast<OpResult>((*resultAndSliceOpsPair).first)
+                            .getDefiningOp<LoopLikeOpInterface>());
+  LoopLikeOpInterface outerMostLoop = outerLoops.front();
+
+  // 2. Get first tilable consumer op
   FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getUntiledConsumerFromSlice(candidateSliceOp);
+      getTilableConsumerOperandFirstUseVal((*resultAndSliceOpsPair).first,
+                                           outerMostLoop);
   if (failed(maybeConsumerOpOperand)) {
     return rewriter.notifyMatchFailure(candidateSliceOp,
                                        "could not fetch consumer to fuse");
@@ -1316,111 +1498,193 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  Operation *oldLoopOp = nullptr;
-  SmallVector<Value> newOuts;
-  Block *oldLoopBody = nullptr;
-  unsigned initSize = 0;
-  unsigned rank = 1;
-  if (isInsertSliceOp) {
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-    oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
-    initSize = forOp.getInits().size();
-  } else {
-    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
-    oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
-    initSize = forallOp.getOutputs().size();
-    rank = forallOp.getRank();
-  }
-
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
-    return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
-  }
-
-  OpBuilder::InsertionGuard g(rewriter);
-
-  // 2. Check consumer is not using scf loop's output as init.
+  // 3. Check consumer is not using outerMostLoop's output as init.
   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
 
-  Location loc = oldLoopOp->getLoc();
+  // 4.a. Reconstruct nested loop from outer to inner.
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
+      (*resultAndSliceOpsPair).second;
+  SmallVector<LoopLikeOpInterface> newOuterLoops;
+  SmallVector<SmallVector<OpFoldResult>> allResultOffsets, allResultSizes;
+  SmallVector<tensor::ExtractSliceOp> newExtractOps;
 
-  // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
-  Operation *newLoopOp = nullptr;
+  Block *oldLoopBody = nullptr;
   Block *newLoopBody = nullptr;
+  SmallVector<Value> newInitAppend = dpsInits, newOuts;
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // 4.b. Set insertPoint right before consumerOp
+  rewriter.setInsertionPoint(consumerOp);
+
+  for (auto [index, loop] :
+       llvm::enumerate(MutableArrayRef(outerLoops).drop_back())) {
+    if (index > 0) {
+      rewriter.setInsertionPoint(loop);
+      // 4.c. Create `extractSliceOp` for newInits if they come from the
+      // sharedOuts of the previous `scf.forall` loop.
+      if (auto prevOuterLoop =
+              dyn_cast<scf::ForallOp>(newOuterLoops.back().getOperation())) {
+        if (index != 1) {
+          return rewriter.notifyMatchFailure(
+              prevOuterLoop, "Currently only outerMostLoop assumed forallOp");
+        }
+        OffsetSizeAndStrideOpInterface outerMostCandidate =
+            candidateSliceOpList.back();
+        if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+                rewriter, cast<TilingInterface>(consumerOp), operandNumber,
+                outerMostCandidate, allResultOffsets, allResultSizes))) {
+          return failure();
+        }
+        newExtractOps = fixLoopInitFromSharedOutSCFForall(
+            rewriter, loop, newInitAppend, allResultOffsets, allResultSizes);
+        newInitAppend = llvm::map_to_vector(
+            newExtractOps,
+            [](tensor::ExtractSliceOp op) -> Value { return op.getResult(); });
+      }
+    }
+    LoopLikeOpInterface newLoopOp;
+    // 4.d. Create a new loop with the new init values for this loop.
+    if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+      newOuts = llvm::to_vector(forOp.getInits());
+      newOuts.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForOp>(
+          forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+          forOp.getStep(), newOuts);
+      newLoopOp = newLoop;
+      oldLoopBody = forOp.getBody();
+      newLoopBody = newLoop.getBody();
+    } else if (auto forallOp = dyn_cast<scf::ForallOp>(loop.getOperation())) {
+      newOuts = llvm::to_vector(forallOp.getOutputs());
+      newOuts.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForallOp>(
+          forallOp.getLoc(), forallOp.getMixedLowerBound(),
+          forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+          forallOp.getMapping());
+      rewriter.eraseOp(newLoop.getTerminator());
+      newLoopOp = newLoop;
+      oldLoopBody = forallOp.getBody();
+      newLoopBody = newLoop.getBody();
+    }
+    newInitAppend = llvm::map_to_vector(
+        newLoopBody->getArguments().take_back(newInitAppend.size()),
+        [](BlockArgument bArg) -> Value { return bArg; });
+    rewriter.mergeBlocks(
+        oldLoopBody, newLoopBody,
+        newLoopBody->getArguments().take_front(oldLoopBody->getNumArguments()));
+    rewriter.replaceOp(
+        loop, newLoopOp->getResults().take_front(loop->getNumResults()));
+    newOuterLoops.push_back(newLoopOp);
+  }
+
+  // 5.a. Reconstruct inner-most loop.
+  LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
+  Location loc = oldInnerMostLoop->getLoc();
+  if (outerLoops.size() > 1)
+    rewriter.setInsertionPoint(oldInnerMostLoop);
+
   if (isInsertSliceOp) {
-    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto forOp = cast<scf::ForOp>(oldInnerMostLoop.getOperation());
+    newOuts = llvm::to_vector(forOp.getInits());
+    newOuts.append(newInitAppend.begin(), newInitAppend.end());
+    oldLoopBody = forOp.getBody();
     auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                 forOp.getUpperBound(),
                                                 forOp.getStep(), newOuts);
-    newLoopOp = newForOp;
+    newInnerMostLoop = newForOp;
     newLoopBody = newForOp.getBody();
   } else {
-    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto forallOp = cast<scf::ForallOp>(oldInnerMostLoop.getOperation());
+    newOuts = llvm::to_vector(forallOp.getOutputs());
+    newOuts.append(newInitAppend.begin(), newInitAppend.end());
+    oldLoopBody = forallOp.getBody();
     auto newForallOp = rewriter.create<scf::ForallOp>(
         loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
         forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-    newLoopOp = newForallOp;
+    newInnerMostLoop = newForallOp;
     rewriter.eraseOp(newForallOp.getTerminator());
     newLoopBody = newForallOp.getBody();
   }
 
-  // 4. Move the loop body to the new op.
+  // 5.b. Move the loop body to the new op.
   unsigned oldNumArguments = oldLoopBody->getNumArguments();
   rewriter.mergeBlocks(oldLoopBody, newLoopBody,
                        newLoopBody->getArguments().take_front(oldNumArguments));
+  // 5.c. Replace the results of oldInnerMostLoop with the leading results of
+  // newInnerMostLoop.
+  rewriter.replaceOp(oldInnerMostLoop,
+                     newInnerMostLoop->getResults().take_front(
+                         oldInnerMostLoop->getNumResults()));
 
-  // 5. Set insertion point before terminator op of the loop and create a new
+  // 6. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
   tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
     rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
   }
 
-  // 6.a. Clone consumer op.
-  auto newForOpBlockArgsForConsumerDest =
-      newLoopBody->getArguments().drop_front(oldNumArguments);
+  // 7.a. `getTiledImplementation` currently assumes that all operands are
+  // untiled and have the original tensor size, so create a dummy
+  // `insertSliceOp` to satisfy that requirement.
+  auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
+  FailureOr<SmallVector<OpFoldResult>> realOffsets =
+      computeRealOffsetsCoordinatedRootSliceOp(rewriter, loc, ossSliceOp,
+                                               candidateSliceOpList);
+  if (failed(realOffsets))
+    return failure();
+  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+      loc, candidateSliceOp->getOperand(0),
+      candidateSliceOpList.back()->getOperand(1), *realOffsets,
+      ossSliceOp.getMixedSizes(), ossSliceOp.getMixedStrides());
+
+  SmallVector<Value> newDpsInitsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(oldNumArguments),
+      [](BlockArgument bArg) -> Value { return bArg; });
+
+  // 7.b. If the outerMostLoop is scf.forall, `newExtractOps` have additionally
+  // been created at `step 4.c` for `dpsInits`. As their counterpart, an
+  // `insertSliceOp` is also needed, for the same purpose as in `step 7.a`.
+  if (!newExtractOps.empty()) {
+    for (auto &&[extractOp, newDpsInit] :
+         llvm::zip_equal(newExtractOps, newDpsInitsForConsumerDest)) {
+      auto alignDpsInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+          loc, newDpsInit, extractOp.getSource(), extractOp.getMixedOffsets(),
+          extractOp.getMixedSizes(), extractOp.getMixedStrides());
+      newDpsInit = alignDpsInsertSliceOp.getResult();
+    }
+  }
+
+  // 8.a. Clone consumer op.
   auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+      rewriter, consumerOp, newDpsInitsForConsumerDest));
 
-  // 6.b. Replace all uses of the loop result with the result of the cloned
+  // 8.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
     operandToReplace.set(clonedInsertSliceOp.getResult());
   });
 
-  // 7 - Perform tiling of the cloned consumer and replace the operand at
+  // 9. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
-  auto ossSliceOp =
-      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
       tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedInsertSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
   if (failed(tileAndFuseResult)) {
     return failure();
   }
@@ -1428,75 +1692,100 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
       tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
       clonedInsertSliceOp.getSource());
 
-  // 8 - Extract offset/sizes/strides required to create the
-  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
-  // 9. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
-
-  // 10. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
-
-  // 11. Try to fetch the offset and size for all results of the cloned
+  // 10. Try to fetch the offset and size for all results of the cloned
   // consumer. This would then be used to form the corresponding
   // tensor.insert_slice/parallel_insert_slice later.
-  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-    if (failed(clonedConsumerOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
-      return rewriter.notifyMatchFailure(
-          clonedConsumerOp,
-          "can't get result domain position from iter domain position");
-    }
+  if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+          rewriter, clonedConsumerOp, operandNumber, ossSliceOp,
+          allResultOffsets, allResultSizes))) {
+    return failure();
+  }
+
+  if (!newExtractOps.empty()) {
+    fixDpsInitsOfTiledConsumer(
+        rewriter, tileAndFuseResult->tiledOps[0],
+        newLoopBody->getArguments().drop_front(oldNumArguments),
+        allResultOffsets, allResultSizes);
   }
 
-  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
-  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
   if (isInsertSliceOp) {
-    auto newForOp = cast<scf::ForOp>(newLoopOp);
+    auto newForOp = cast<scf::ForOp>(newInnerMostLoop);
     fixTerminatorSCFYield(
-        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
-        newForOp.getBody()->getArguments().drop_front(1 + initSize));
+        rewriter, newForOp, tileAndFuseResult->tiledOps[0]->getResults(),
+        allResultOffsets, allResultSizes,
+        newForOp.getBody()->getArguments().take_back(newInitAppend.size()));
   } else {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
     fixTerminatorSCFInParallel(
         rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
-        arrayRefOffsets, arrayRefSizes,
-        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
+        allResultOffsets, allResultSizes,
+        newForallOp.getBody()->getArguments().take_back(newInitAppend.size()));
   }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
-  for (auto &&[oldResult, newResult] :
-       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
-    rewriter.replaceAllUsesWith(oldResult, newResult);
+  newOuterLoops.push_back(cast<LoopLikeOpInterface>(newInnerMostLoop));
+
+  // 11.a. Reconstruct the terminator of each outer loop from its inner loop.
+  auto outerCandidateIter = candidateSliceOpList.rbegin();
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip_equal(MutableArrayRef(newOuterLoops).drop_back(),
+                       MutableArrayRef(newOuterLoops).drop_front())) {
+    // 11.b. Create insertSliceOp according to the outer candidateSliceOp.
+    if (outerCandidateIter != candidateSliceOpList.rend() &&
+        outerCandidateIter->getOperation()
+                ->getParentOfType<LoopLikeOpInterface>() == outerLoop) {
+      if (auto forallOp = dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
+        rewriter.setInsertionPoint(forallOp.getTerminator());
+      } else {
+        rewriter.setInsertionPointAfter(*outerCandidateIter);
+      }
+
+      if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+              rewriter, clonedConsumerOp, operandNumber, *outerCandidateIter,
+              allResultOffsets, allResultSizes))) {
+        return failure();
+      }
+
+      if (auto forOp = dyn_cast<scf::ForOp>(outerLoop.getOperation())) {
+        fixTerminatorSCFYield(
+            rewriter, forOp,
+            innerLoop->getResults().take_back(newInitAppend.size()),
+            allResultOffsets, allResultSizes,
+            forOp.getBody()->getArguments().take_back(newInitAppend.size()));
+      } else if (auto forallOp =
+                     dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
+        fixTerminatorSCFInParallel(
+            rewriter, forallOp,
+            innerLoop->getResults().take_back(newInitAppend.size()),
+            allResultOffsets, allResultSizes,
+            forallOp.getBody()->getArguments().take_back(newInitAppend.size()));
+      }
+      outerCandidateIter++;
+    } else {
+      // 11.c. Additionally yield the newly appended results of innerLoop.
+      assert(isa<scf::ForOp>(outerLoop));
+      auto forOp = cast<scf::ForOp>(outerLoop);
+      auto outerLoopYield =
+          cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+      SmallVector<Value> newYields =
+          llvm::to_vector(outerLoopYield.getOperands());
+      ValueRange additionalYields =
+          innerLoop->getResults().take_back(newInitAppend.size());
+      newYields.append(additionalYields.begin(), additionalYields.end());
+      rewriter.setInsertionPoint(outerLoopYield);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+    }
   }
 
+  // 12. Replace the results of the consumer op with the new outermost loop's
+  // results.
   for (auto &&[oldResult, newResult] :
        llvm::zip(consumerOp->getResults(),
-                 newLoopOp->getResults().drop_front(initSize))) {
+                 newOuterLoops.front()->getResults().take_back(
+                     newInitAppend.size()))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the old scf loop and the cloned consumer op.
-  rewriter.eraseOp(oldLoopOp);
+  // 13. Need to erase the cloned consumer op.
   rewriter.eraseOp(clonedConsumerOp);
 
   return scf::SCFFuseConsumerOfSliceResult{
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcd..d60d5f4fd7b3c 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,3 +315,99 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg3)
+      %iv1 = affine.apply #map(%arg4)
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        scf.yield %3 : tensor<128x128xf32>
+      }
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
+    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %5 : tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//      CHECK: #[[MAP0:.*]] =  affine_map<(d0) -> (d0 * 128)>
+//      CHECK: #[[MAP1:.*]] =  affine_map<(d0, d1) -> (d0 + d1 * 128)>
+//      CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[dest1:.*]] = linalg.fill
+// CHECK-SAME:          outs(%[[dest0]] :
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME:   {
+//      CHECK:      %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+//      CHECK:      %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+//      CHECK:      %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+//      CHECK:      %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+//      CHECK:      %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME:          iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
+// CHECK-SAME:      {
+//      CHECK:          %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME:            iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
+// CHECK-SAME:          {
+//      CHECK:            %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+//      CHECK:            %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE1]] :
+//      CHECK:            %[[REAL_SECOND_IV1:.*]] = affine.apply #[[MAP1]](%[[IV3]], %[[IV1]])
+//      CHECK:            %[[REAL_SECOND_IV2:.*]] = affine.apply #[[MAP1]](%[[IV4]], %[[IV2]])
+//      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[REAL_SECOND_IV1]], %[[REAL_SECOND_IV2]]] [64, 64] [1, 1]
+//      CHECK:            %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
+// CHECK-SAME:              outs(%[[ADD_OUT_SLICE1]] :
+//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
+//      CHECK:          }
+//      CHECK:          scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
+//      CHECK:      }
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :



More information about the Mlir-commits mailing list