[Mlir-commits] [mlir] [mlir][SCF] Add `scf::tileAndFuseConsumer` that tiles a consumer into a given tiled loop nest. (PR #167634)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 20 11:52:02 PST 2025


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/167634

>From 640cd804bd655fe1aabc8c6e78e4e8a3f0529dfa Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Mon, 10 Nov 2025 21:42:24 -0800
Subject: [PATCH 1/4] [mlir][SCF] Add `scf::tileAndFuseConsumer` that tiles a
 consumer into a given tiled loop nest.

The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices
(and the loops they are part of), tries to find the consumer of these
slices (all slices are expected to have the same consumer), and then
tiles the consumer into the loop nest using the `TilingInterface`. A
more natural way of doing consumer fusion is to just start from the
consumer, look for operands that are produced by the loop nest passed
in as `loops` (presumably these loops are generated by tiling, but
that is not a requirement for consumer fusion). Using the consumer you
can find the slices of the operands that are accessed within the loop
which you can then use to tile and fuse the consumer (using
`TilingInterface`). This handles more naturally the case where
multiple operands of the consumer come from the loop nest.
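
As a rough illustration of the consumer-first lookup described above, here
is a minimal sketch in plain C++ that does not depend on MLIR. The types
`ToyOp`, `ToyValue`, and `ToyKind`, and the helpers `collectFusableOperands`
and `findProducingSlice`, are hypothetical stand-ins invented for this
example; they only mirror the shape of the logic and are not the actual
implementation in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-ins for MLIR concepts; none of these exist in MLIR.
struct ToyOp;

// A value is identified by its defining op (nullptr for a function
// argument) and the result index on that op.
struct ToyValue {
  const ToyOp *definingOp = nullptr;
  std::size_t resultNumber = 0;
};

enum class ToyKind { Loop, InsertSlice, Other };

struct ToyOp {
  ToyKind kind = ToyKind::Other;
  std::vector<ToyValue> operands;
  // For loops only: values yielded by the body terminator, one per result.
  std::vector<ToyValue> yielded;
};

// Step 1: collect the indices of the consumer operands that are defined by
// the outermost loop of the nest.
std::vector<std::size_t> collectFusableOperands(const ToyOp &consumer,
                                                const ToyOp &outermostLoop) {
  std::vector<std::size_t> indices;
  for (std::size_t i = 0; i < consumer.operands.size(); ++i)
    if (consumer.operands[i].definingOp == &outermostLoop)
      indices.push_back(i);
  return indices;
}

// Step 2: follow a result of the outermost loop through the yields of a
// perfectly nested loop nest (outermost first) until reaching the
// insert-slice op in the innermost body. Returns nullopt if the nest does
// not have the expected structure.
std::optional<const ToyOp *>
findProducingSlice(ToyValue result, const std::vector<const ToyOp *> &loops) {
  assert(!loops.empty() && "expected a non-empty loop nest");
  for (const ToyOp *loop : loops) {
    if (result.definingOp != loop ||
        result.resultNumber >= loop->yielded.size())
      return std::nullopt;
    result = loop->yielded[result.resultNumber];
  }
  if (!result.definingOp || result.definingOp->kind != ToyKind::InsertSlice)
    return std::nullopt;
  return result.definingOp;
}
```

In the patch itself the corresponding entry point is
`scf::tileAndFuseConsumer(rewriter, consumer, loops)`, which performs both
steps on real operations before tiling and fusing the consumer.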

The `scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of
`scf::tileAndFuseProducerOfSlice`. For the latter, the slice has a
single producer for the source of the slice, which makes it a natural
way of specifying producer fusion. But for consumers, the result might
have multiple users, resulting in multiple candidates for fusion, as
well as a fusion candidate using multiple results from the tiled loop
nest. This means using slices
(`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for
consumer fusion turns out to be quite hard to navigate. The use of the
consumer directly avoids all those pain points. In time the
`scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of
`scf::tileAndFuseConsumer`. A lot of tech debt has accumulated in
`scf::tileAndFuseConsumerOfSlices` and needs to be cleaned up. While
that cleanup happens and the required functionality is moved to
`scf::tileAndFuseConsumer`, the old path is still maintained.

The tests for `scf::tileAndFuseConsumerOfSlices` are copied from
`tile-and-fuse-consumer.mlir` to
`tile-and-fuse-consumer-using-slices.mlir`. All the tests that were
in the original file now use the `tileAndFuseConsumer` method. The
test op `test.tile_and_fuse_consumer` is modified to call
`scf::tileAndFuseConsumer`, while a new op
`test.tile_and_fuse_consumer_of_slice` is used to keep the old path
tested while it is being deprecated.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |    5 +
 .../SCF/Transforms/TileUsingInterface.h       |   12 +
 .../SCF/Transforms/TileUsingInterface.cpp     |  221 +++-
 .../transform-tile-and-fuse-pack-unpack.mlir  |    4 +-
 .../tile-and-fuse-consumer-using-slices.mlir  | 1156 +++++++++++++++++
 .../tile-and-fuse-consumer.mlir               |  380 +++---
 .../TestTilingInterfaceTransformOps.cpp       |   79 +-
 .../TestTilingInterfaceTransformOps.td        |   24 +-
 8 files changed, 1630 insertions(+), 251 deletions(-)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index cd033c140a233..8bdf3e0b566ef 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
+    BlockArgument getTiedBlockArgument(OpResult opResult) {
+      assert(opResult.getDefiningOp() == getOperation()  && "invalid OpResult");
+      return getBody()->getArgument(getRank() + opResult.getResultNumber());
+    }
+
     ::mlir::Value getInductionVar(int64_t idx) {
       return getInductionVars()[idx];
     }
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 7c735d825b445..0005fad3d5c01 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
 /// tiled in a manner that is consistent for all the passed slices. Note that
 /// the method replaces the uses of `candidateSlices` with the tiled and fused
 /// consumer value but does not delete the slice operations.
+/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion
+/// is to take the consumer operation, and find the slices to use for fusion
+/// by walking its operands to the `loops` and then into the body to get the
+/// slices used for fusion.
 struct SCFFuseConsumerOfSliceResult {
   // Original untiled consumer operands.
   SmallVector<OpOperand *> origConsumerOperands;
@@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
                             ArrayRef<Operation *> candidateSlices,
                             MutableArrayRef<LoopLikeOpInterface> loops);
 
+/// Fuse the `consumer` operation into the loop nest provided by `loops`.
+/// The transformation looks for operands in the `consumer` that are defined
+/// by the outermost loop of the loop nest in `loops`. The nested loop is
+/// expected to have the structure of the loops generated through tiling.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
+                    MutableArrayRef<LoopLikeOpInterface> loops);
+
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
 FailureOr<SmallVector<scf::ForOp>>
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 29b770fb4b279..7e715ee189740 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
   for (auto [outerLoop, innerLoop] :
        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
     // Again assume that all the outer loops are scf.for operations.
-    auto outerForLoop = cast<scf::ForOp>(outerLoop);
+    auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
     auto outerLoopYield =
         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
     SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
   return clonedSlices;
 }
 
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf loop.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlices(
-    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
-  if (candidateSlices.empty()) {
-    return rewriter.notifyMatchFailure(
-        rewriter.getUnknownLoc(),
-        "no candidate slices provided for consumer fusion");
-  }
-  // Return if `loops` is empty, return an error for now. Caller is expected
-  // to handle this case.
-  if (loops.empty()) {
-    return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "cannot call tile and fuse consumer with an empty loop nest");
-  }
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
+                                ArrayRef<OpOperand *> consumerOpOperands,
+                                ArrayRef<Operation *> candidateSlices,
+                                MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "expected loops to be not empty");
 
-  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
-        llvm::all_of(candidateSlices,
-                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+  // 1. Check assumption for loop with `reorderOperations` disabled.
+  if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
     return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "candidates slices need to be all `tensor.extract_slice`s or "
-        "`tensor.parallel_insert_slice`s");
-  }
-
-  // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice/parallel_insert_slice.
-  SmallVector<OpOperand *> consumerOpOperands;
-  Operation *consumerOp;
-  {
-    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
-        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
-    if (failed(maybeConsumerOpOperand)) {
-      return rewriter.notifyMatchFailure(candidateSlices.front(),
-                                         "could not fetch consumer to fuse");
-    }
-    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
-    consumerOp = consumerOpOperands.front()->getOwner();
+        loops.front(), "the first user of loop should not dominate any define "
+                       "of consumer operand(s)");
   }
 
   LoopLikeOpInterface outerMostLoop = loops.front();
   LoopLikeOpInterface innerMostLoop = loops.back();
 
-  // Check assumption for loop with `reorderOperations` disabled.
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
-    return rewriter.notifyMatchFailure(
-        outerMostLoop, "the first user of loop should not dominate any define "
-                       "of consumer operand(s)");
-  }
-
   OpBuilder::InsertionGuard g(rewriter);
-
   // 2. Check consumer is not using scf loop's output as init.
   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
   if (!dstOp)
@@ -2428,11 +2391,171 @@ mlir::scf::tileAndFuseConsumerOfSlices(
       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
         return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
       });
+  auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
   return scf::SCFFuseConsumerOfSliceResult{
-      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
       std::move(tileAndFuseResult->tiledOps)};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return rewriter.notifyMatchFailure(
+        rewriter.getUnknownLoc(),
+        "no candidate slices provided for consumer fusion");
+  }
+  // If `loops` is empty, return an error for now. The caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+        llvm::all_of(candidateSlices,
+                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "candidates slices need to be all `tensor.insert_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
+
+  // Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  SmallVector<OpOperand *> consumerOpOperands;
+  Operation *consumerOp;
+  {
+    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+    if (failed(maybeConsumerOpOperand)) {
+      return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                         "could not fetch consumer to fuse");
+    }
+    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+    consumerOp = consumerOpOperands.front()->getOwner();
+  }
+
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
+}
+
+/// For a given `result` of a `forallOp` return the
+/// `tensor.parallel_insert_slice` op (or combining op) that is used to
+/// construct this result.
+static std::optional<Operation *>
+getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
+  if (result.getOwner() != forallOp)
+    return std::nullopt;
+  BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
+  SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
+  // If the number of combining ops is not 1, then this is unexpected. Return
+  // nullopt.
+  if (combiningOps.size() != 1) {
+    return std::nullopt;
+  }
+  return combiningOps[0];
+}
+
+/// For a given result of a tiled loop nest, return the insert-slice-like op
+/// that is used for consumer fusion.
+static std::optional<Operation *>
+getProducingInsertSliceLikeOp(OpResult result,
+                              ArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "Expected loops to be not empty");
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {
+    assert(loops.size() == 1 &&
+           "expected only a single loop when tiling using scf.forall");
+    return getProducingParallelInsertSlice(forallOp, result);
+  }
+  // Assume that the loop nest is a nested `scf.for` that is created through
+  // tiling and retrieve the `tensor.insert_slice` operation used to construct
+  // the result.
+  while (loops.size() != 1) {
+    if (result.getOwner() != loops.front())
+      return std::nullopt;
+    auto forOp = dyn_cast<scf::ForOp>(loops.front());
+    if (!forOp)
+      return std::nullopt;
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    OpResult innerForResult =
+        dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
+    if (!innerForResult)
+      return std::nullopt;
+    result = innerForResult;
+    loops = loops.drop_front();
+  }
+  if (result.getOwner() != loops.front())
+    return std::nullopt;
+  auto forOp = dyn_cast<scf::ForOp>(loops.front());
+  if (!forOp)
+    return std::nullopt;
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
+                           .getDefiningOp<tensor::InsertSliceOp>();
+  if (!insertSliceOp)
+    return std::nullopt;
+  return insertSliceOp;
+}
+
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
+                               MutableArrayRef<LoopLikeOpInterface> loops) {
+  // Only handle users that implement the `TilingInterface`.
+  if (!isa<TilingInterface>(user)) {
+    return rewriter.notifyMatchFailure(
+        user, "unhandled user that does not implement TilingInterface");
+  }
+
+  // If `loops` is empty, return an error for now. The caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        user, "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  // Collect the operands of the user that come from the outermost loop of the
+  // loop nest.
+  SmallVector<OpOperand *> consumerFusableOperands;
+  for (OpOperand &opOperand : user->getOpOperands()) {
+    if (opOperand.get().getDefiningOp() == outermostLoop) {
+      consumerFusableOperands.push_back(&opOperand);
+    }
+  }
+
+  // Nothing to fuse. Just return an empty set.
+  if (consumerFusableOperands.empty()) {
+    return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
+                                                   SmallVector<OpOperand *>{},
+                                                   SmallVector<Operation *>{}};
+  }
+
+  // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
+  // for fusion.
+  SmallVector<Operation *> candidateSlices;
+  candidateSlices.reserve(consumerFusableOperands.size());
+  for (OpOperand *opOperand : consumerFusableOperands) {
+    std::optional<Operation *> slice =
+        getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
+    if (!slice) {
+      return rewriter.notifyMatchFailure(
+          user,
+          "could not find producing insert-slice like operation for operand");
+    }
+    candidateSlices.push_back(slice.value());
+  }
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, user, consumerFusableOperands, candidateSlices, loops);
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 185fb9b358055..d72ab080f3c5c 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -170,7 +170,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
           : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      transform.test.fuse_consumer %slice_op in (%forall_op)
+      transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
         : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.yield
     }
@@ -231,7 +231,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
           : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
       // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer
       // to fuse" error.
       transform.yield
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
new file mode 100644
index 0000000000000..62dd7faec4eb7
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
@@ -0,0 +1,1156 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %0 = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0)
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+         tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+    %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+        : (!transform.any_op)
+        -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %out_operand_4 = tensor.empty() : tensor<64xf32>
+    %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.subf %out_0, %13 : f32
+          %15 = arith.addf %out_1, %in : f32
+          linalg.yield %14, %15 : f32, f32
+    } -> (tensor<64xf32>, tensor<64xf32>)
+    return %2#1 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %0 = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %0, %[[ELEM_OUT_ARG_1:.*]] = %0)
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+//      CHECK:      %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+//      CHECK:      %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#3 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+        tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
+    }
+    %1 = tensor.empty() : tensor<64x64xf32>
+    %2 = tensor.empty() : tensor<64x64xf32>
+    %3 = tensor.empty() : tensor<64x64xf32>
+    %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
+      %6 = arith.mulf %in, %in_0 : f32
+      %7 = arith.subf %out, %6 : f32
+      %8 = arith.addf %out_1, %in : f32
+      linalg.yield %7, %8 : f32, f32
+    } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+    %5 = tensor.empty() : tensor<2048xf32>
+    %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
+    return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+        : (!transform.any_op)
+        -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   %[[UNPACK:.*]] = linalg.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32>
+//      CHECK:   return %[[FINAL_RESULT]]#3, %[[UNPACK]] :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
+    }
+    %output = tensor.empty() : tensor<2048xf32>
+    %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
+    return %unpack : tensor<2048xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop)
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+//  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
+//      CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
+//  CHECK-DAG:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
+//  CHECK-DAG:      %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+//      CHECK:      %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+//      CHECK:      %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]]
+// CHECK-SAME:                              outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:                              into %[[TILED_UNPACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
+    }
+    %output = tensor.empty() : tensor<2047xf32>
+    %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+    return %unpack : tensor<2047xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop)
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+//  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
+//      CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
+//  CHECK-DAG:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
+//  CHECK-DAG:      %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+//      CHECK:      %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+//      CHECK:      %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]]
+// CHECK-SAME:                              outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:                              into %[[TILED_UNPACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
+    }
+    %output = tensor.empty() : tensor<4x32x16xf32>
+    %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+    return %pack : tensor<4x32x16xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop)
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+//      CHECK: func.func @fuse_perfect_tiling_pack_consumer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 1)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
+//      CHECK:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
+//      CHECK:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+//      CHECK:      %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]]
+// CHECK-SAME:                              inner_dims_pos = [0] inner_tiles = [16]
+// CHECK-SAME:                              into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+
+// -----
+
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> {
+  %0 = tensor.empty() : tensor<1x4x16x1xf32>
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+    %3 = affine.min #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32>
+  return %pack : tensor<1x4x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
+//      CHECK: func.func @fuse_pack_consumer_if_single_iteration(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
+//  CHECK-DAG:   %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME:      shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+//  CHECK-DAG:      %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+//  CHECK-DAG:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//  CHECK-DAG:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+
+// -----
+
+func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
+  %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32>
+  return %pack : tensor<2x64x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+//      CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
+
+// -----
+
+// It is valid to fuse the pack op in a perfect tiling scenario when the
+// dimension is dynamic and padding is not needed.
+
+func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %arg1: tensor<64x?xf32>, %1: tensor<64x?x16xf32>) -> tensor<64x?x16xf32> {
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<64x?xf32>
+  %0 = scf.forall (%arg2) = (0) to (%d1) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x?xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x?xf32>
+    }
+  }
+  %pack = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [16] into %1 : tensor<64x?xf32> -> tensor<64x?x16xf32>
+  return %pack : tensor<64x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+//      CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (%{{.+}}) step (16)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
+
+// -----
+
+// It is valid to fuse the pack op with padding semantics if it is a perfect
+// tiling case.
+
+func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> {
+  %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+    %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2)
+    %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32>
+    %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32>
+    %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<22x2x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<22x2x3x16xf32>
+  return %pack : tensor<22x2x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)>
+//  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)>
+//  CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)>
+//  CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+//      CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32>
+//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16)
+// CHECK-SAME:      shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]])
+//      CHECK:      %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:        [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]]
+// CHECK-SAME:        [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]])
+//  CHECK-DAG:      %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]])
+//  CHECK-DAG:      %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]])
+//  CHECK-DAG:      %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]]
+// CHECK-SAME:        [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME:        into %[[PACK_INIT]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]]
+// CHECK-SAME:            [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]]
+// CHECK-SAME:            [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+// Imperfect tiling is not supported in pack op consumer fusion.
+
+#map = affine_map<(d0) -> (d0 * 5)>
+#map1 = affine_map<(d0) -> (d0)>
+func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> {
+  %0 = tensor.empty() : tensor<30xf32>
+  %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) {
+    %3 = affine.apply #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+    %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %5 = arith.addf %in, %in : f32
+      linalg.yield %5 : f32
+    } -> tensor<5xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32>
+    }
+  }
+  %2 = tensor.empty() : tensor<5x6xf32>
+  %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32>
+  return %pack : tensor<5x6xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+module {
+  func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+        %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+        %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+        scf.yield %insert_slice : tensor<256x256xf32>
+    }
+    %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 2
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_add_multiple_tilable_consumers(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
+// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME:   {
+//      CHECK:          %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME:                ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
+// CHECK-SAME:                outs(%[[ADD_OUT_SLICE]] :
+//      CHECK:          %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME:                ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME:                outs(%[[EXP_OUT_SLICE]] :
+//      CHECK:          %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_MUL_OUT:.*]] = linalg.mul
+// CHECK-SAME:                ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
+// CHECK-SAME:                outs(%[[MUL_OUT_SLICE]] :
+//      CHECK:          %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
+//      CHECK:   }
+//      CHECK:   return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
+
+// -----
+
+module {
+  func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+        %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+        %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+        scf.yield %insert_slice : tensor<256x256xf32>
+    }
+    %dest1 = tensor.empty() : tensor<258x258xf32>
+    %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
+    %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @no_fuse_only_dps_consumer(
+//      CHECK:   %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
+//      CHECK:     linalg.add
+//      CHECK:     linalg.mul
+//      CHECK:     scf.yield
+//      CHECK:   }
+//      CHECK:   %[[RES_SLICE:.+]] = tensor.insert_slice
+//      CHECK:   return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+  func.func @fuse_with_tilable_consumer_with_projected_permutations(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<24xf32>) -> tensor<256x256x24xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %0 = tensor.empty() : tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %0) -> (tensor<256x256xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %4 = linalg.add ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice : tensor<64x256xf32>) -> tensor<64x256xf32>
+      %inserted_slice = tensor.insert_slice %4 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+      scf.yield %inserted_slice : tensor<256x256xf32>
+    }
+    %2 = tensor.empty() : tensor<256x256x24xf32>
+    %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg2 : tensor<256x256xf32>, tensor<24xf32>) outs(%2 : tensor<256x256x24xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %4 = arith.addf %in, %in_0 : f32
+      linalg.yield %4 : f32
+    } -> tensor<256x256x24xf32>
+    return %3 : tensor<256x256x24xf32>
+  }
+}
+
+// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:             %[[VAL_4:.*]] = arith.constant 64 : index
+// CHECK:             %[[VAL_5:.*]] = arith.constant 256 : index
+// CHECK:             %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK:             %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
+// CHECK:             %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
+// CHECK:               %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
+// CHECK:               %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
+// CHECK:               %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK:               %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
+// CHECK:               ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK:                 %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK:                 linalg.yield %[[VAL_23]] : f32
+// CHECK:               } -> tensor<64x256x24xf32>
+// CHECK:               %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_19]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK:               scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
+// CHECK:             }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) =  (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic:2 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+        %1 = arith.addf %b0, %b2 : f32
+        linalg.yield %0, %1 : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion1(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//   CHECK-DAG:     %[[GENERIC:.+]]:2 = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that the consumer is fused when its two operands come from slices produced by separate, consistently tiled operations.
+
+func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) =  (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice : tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.addf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion2(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]], %[[GENERIC1]] :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) =  (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice : tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.addf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1
+//   CHECK-DAG:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//   CHECK-DAG:     %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
+//   CHECK-DAG:     %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]], %[[GENERIC1]] :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+//       CHECK:   return %[[RESULT]]#2
+
+// -----
+
+func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) =  (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.addf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 78884625ce7dc..0137e2a69a46e 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics --mlir-print-local-scope %s | FileCheck %s
 
 #map = affine_map<(d0) -> (d0)>
 module {
-  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+  func.func @fuse_tilable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -28,14 +28,14 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %loop = transform.structured.match ops{["scf.for"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+    %add = transform.structured.match ops{["linalg.add"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %yield in (%loop)
+    %a, %new_loop = transform.test.fuse_consumer %add into (%loop)
       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//      CHECK: func.func @fuse_tileable_consumer_scf_for(
+//      CHECK: func.func @fuse_tilable_consumer_scf_for(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
@@ -60,8 +60,61 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0) -> (d0)>
 module {
-  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  func.func @fuse_tilable_consumer_nested_scf_for(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+      %lb0 : index, %ub0 : index, %step0 : index,
+      %lb1 : index, %ub1 : index, %step1 : index) -> tensor<?x?xf32> {
+    %0 = scf.for %arg3 = %lb0 to %ub0 step %step0 iter_args(%init0 = %arg0) -> tensor<?x?xf32> {
+      %1 = scf.for %arg4 = %lb1 to %ub1 step %step1 iter_args(%init1 = %init0) -> tensor<?x?xf32> {
+        %extracted_slice = tensor.extract_slice %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %2 = tensor.insert_slice %extracted_slice into %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+        scf.yield %2 : tensor<?x?xf32>
+      }
+      scf.yield %1 : tensor<?x?xf32>
+    }
+    %2 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %2 : tensor<?x?xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loops = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop0, %loop1 = transform.split_handle %loops
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %add = transform.structured.match ops{["linalg.add"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %new_loop0, %new_loop1 = transform.test.fuse_consumer %add into (%loop0, %loop1)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func @fuse_tilable_consumer_nested_scf_for(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK: %[[OUTER_RESULT:.+]]:2 = scf.for
+// CHECK-SAME:     iter_args(%[[INIT00:[a-zA-Z0-9_]+]] = %[[ARG0]], %[[INIT01:[a-zA-Z0-9_]+]] = %[[ARG2]])
+//      CHECK:   %[[INNER_RESULT:.+]]:2 = scf.for
+// CHECK-SAME:       iter_args(%[[INIT10:[a-zA-Z0-9_]+]] = %[[INIT00]], %[[INIT11:[a-zA-Z0-9_]+]] = %[[INIT01]])
+//  CHECK-DAG:     %[[OPERAND1:.+]] = tensor.extract_slice %[[INIT10]]
+//  CHECK-DAG:     %[[OLD_INSERT_SLICE:.+]] = tensor.insert_slice %[[OPERAND1]] into %[[INIT10]]
+//  CHECK-DAG:     %[[OPERAND2:.+]] = tensor.extract_slice %[[ARG1]]
+//  CHECK-DAG:     %[[INIT:.+]] = tensor.extract_slice %[[INIT11]]
+//      CHECK:     %[[ADD:.+]] = linalg.add
+// CHECK-SAME:         ins(%[[OPERAND1]], %[[OPERAND2]] :
+// CHECK-SAME:         outs(%[[INIT]] :
+//      CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[ADD]] into %[[INIT11]]
+//      CHECK:     scf.yield %[[OLD_INSERT_SLICE]], %[[INSERT_SLICE]]
+//      CHECK:   scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1
+//      CHECK: return %[[OUTER_RESULT]]#1
+
+// -----
+
+module {
+  func.func @fuse_tilable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -83,19 +136,16 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %add = transform.structured.match ops{["linalg.add"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
-        : (!transform.any_op)
-        -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop)
+    %a, %new_loop = transform.test.fuse_consumer %add into (%loop)
       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//      CHECK: func.func @fuse_tileable_consumer_scf_forall(
+//      CHECK: func.func @fuse_tilable_consumer_scf_forall(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
@@ -124,7 +174,7 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0) -> (d0)>
 module {
-  func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+  func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -155,16 +205,18 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
+    %producer, %consumer = transform.split_handle %generics
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %loop = transform.structured.match ops{["scf.for"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %yield in (%loop)
+    %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//      CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
+//      CHECK: func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
@@ -193,7 +245,7 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-  func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+  func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -224,19 +276,16 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
-        : (!transform.any_op)
-        -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop)
+    %a, %new_loops = transform.test.fuse_consumer %generic into (%loop)
       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//      CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
+//      CHECK: func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>
@@ -293,17 +342,15 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1
     : (!transform.any_op) -> !transform.any_op
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
     : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+    %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
     : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
-//  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
 //      CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
@@ -315,8 +362,8 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
 // CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
-//  CHECK-DAG:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
-//  CHECK-DAG:      %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+//  CHECK-DAG:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]])
+//  CHECK-DAG:      %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2048)>(%[[IV1]])
 //      CHECK:      %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
 //      CHECK:      %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]]
 // CHECK-SAME:                              outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
@@ -356,17 +403,15 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1
     : (!transform.any_op) -> !transform.any_op
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
     : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+    %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
     : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
-//  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
 //      CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
@@ -378,8 +423,8 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
 // CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
-//  CHECK-DAG:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
-//  CHECK-DAG:      %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+//  CHECK-DAG:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]])
+//  CHECK-DAG:      %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2047)>(%[[IV1]])
 //      CHECK:      %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
 //      CHECK:      %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]]
 // CHECK-SAME:                              outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
@@ -419,16 +464,15 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %consumer = transform.structured.match ops{["linalg.pack"]} in %arg1
     : (!transform.any_op) -> !transform.any_op
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
     : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+    %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
     : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
 //      CHECK: func.func @fuse_perfect_tiling_pack_consumer(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
@@ -440,7 +484,7 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
 // CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
-//      CHECK:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
+//      CHECK:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV1]])
 //      CHECK:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
 //      CHECK:      %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]]
 // CHECK-SAME:                              inner_dims_pos = [0] inner_tiles = [16]
@@ -471,13 +515,12 @@ func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> ten
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %consumer into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//      CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
 //      CHECK: func.func @fuse_pack_consumer_if_single_iteration(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 //  CHECK-DAG:   %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
@@ -485,7 +528,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
 //      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
 // CHECK-SAME:      shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
-//  CHECK-DAG:      %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+//  CHECK-DAG:      %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 4, 16)>(%[[IV]])
 //  CHECK-DAG:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
 //  CHECK-DAG:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
 //      CHECK:      %[[ELEM:.*]] = linalg.exp
@@ -517,13 +560,12 @@ func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
 //      CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
@@ -535,7 +577,7 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[ELEM:.*]] = linalg.exp
 // CHECK-SAME:        ins(%[[ELEM_SRC]]
 // CHECK-SAME:        outs(%[[ELEM_DEST]]
-//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]])
 //  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
 //      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
 // CHECK-SAME:        outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
@@ -566,13 +608,12 @@ func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
 //      CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
@@ -584,7 +625,7 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[ELEM:.*]] = linalg.exp
 // CHECK-SAME:        ins(%[[ELEM_SRC]]
 // CHECK-SAME:        outs(%[[ELEM_DEST]]
-//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]])
 //  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
 //      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
 // CHECK-SAME:        inner_dims_pos = [1] inner_tiles = [16]
@@ -616,16 +657,12 @@ func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-//  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)>
-//  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)>
-//  CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)>
-//  CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
 //      CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
@@ -633,7 +670,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
 //      CHECK:   %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16)
 // CHECK-SAME:      shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]])
-//      CHECK:      %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])
+//      CHECK:      %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%[[I]])
 //      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]]
 // CHECK-SAME:        [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
 //      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]]
@@ -641,9 +678,9 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[ELEM:.*]] = linalg.exp
 // CHECK-SAME:        ins(%[[ELEM_SRC]]
 // CHECK-SAME:        outs(%[[ELEM_DEST]]
-//  CHECK-DAG:      %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]])
-//  CHECK-DAG:      %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]])
-//  CHECK-DAG:      %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]])
+//  CHECK-DAG:      %[[D0_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%[[I]])
+//  CHECK-DAG:      %[[D0_SIZE:.*]] = affine.apply affine_map<(d0) -> (d0 ceildiv 3)>(%[[SIZE]])
+//  CHECK-DAG:      %[[D1_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[J]])
 //  CHECK-DAG:      %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]]
 // CHECK-SAME:        [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
 //      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
@@ -674,20 +711,21 @@ func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x
       linalg.yield %5 : f32
     } -> tensor<5xf32>
     scf.forall.in_parallel {
-      // expected-error @below {{failed to fuse consumer of slice}}
+      
       tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32>
     }
   }
   %2 = tensor.empty() : tensor<5x6xf32>
+  // expected-error @below {{failed to fuse consumer of slice}}
   %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32>
   return %pack : tensor<5x6xf32>
 }
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -717,11 +755,15 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+    %mulop = transform.structured.match ops{["linalg.mul"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     %loop = transform.structured.match ops{["scf.for"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 2
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %mulop into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %expop = transform.structured.match ops{["linalg.exp"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %fused_consumer_2, %new_loop_2 = transform.test.fuse_consumer %expop into (%new_loop)
       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
@@ -741,64 +783,20 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:                ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
 // CHECK-SAME:                outs(%[[ADD_OUT_SLICE]] :
 //      CHECK:          %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
-//      CHECK:          %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
-//      CHECK:          %[[TILED_EXP_OUT:.*]] = linalg.exp
-// CHECK-SAME:                ins(%[[TILED_ADD_OUT]] :
-// CHECK-SAME:                outs(%[[EXP_OUT_SLICE]] :
 //      CHECK:          %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
-//      CHECK:          %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
 //      CHECK:          %[[TILED_MUL_OUT:.*]] = linalg.mul
 // CHECK-SAME:                ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
 // CHECK-SAME:                outs(%[[MUL_OUT_SLICE]] :
-//      CHECK:          %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
-//      CHECK:          %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
-//      CHECK:          scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
-//      CHECK:   }
-//      CHECK:   return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
-
-// -----
-
-module {
-  func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
-    %c0 = arith.constant 0 : index
-    %c64 = arith.constant 64 : index
-    %c256 = arith.constant 256 : index
-    %cst = arith.constant 0.000000e+00 : f32
-    %dest0 = tensor.empty() : tensor<256x256xf32>
-    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
-        %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
-        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
-        %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
-        %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
-        %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
-        scf.yield %insert_slice : tensor<256x256xf32>
-    }
-    %dest1 = tensor.empty() : tensor<258x258xf32>
-    %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
-    %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
-  }
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1
-      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-//      CHECK: func.func @no_fuse_only_dps_consumer(
-//      CHECK:   %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
-//      CHECK:     linalg.add
-//      CHECK:     linalg.mul
-//      CHECK:     scf.yield
+//      CHECK:          %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME:                ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME:                outs(%[[EXP_OUT_SLICE]] :
+//      CHECK:          %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          scf.yield %[[INSERT_ADD]], %[[INSERT_MUL]], %[[INSERT_EXP]] :
 //      CHECK:   }
-//      CHECK:   %[[RES_SLICE:.+]] = tensor.insert_slice
-//      CHECK:   return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+//      CHECK:   return %[[LOOP_RESULT]]#1, %[[LOOP_RESULT]]#2 :
 
 // -----
 
@@ -829,40 +827,41 @@ module {
   }
 }
 
-// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
-// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:             %[[VAL_4:.*]] = arith.constant 64 : index
-// CHECK:             %[[VAL_5:.*]] = arith.constant 256 : index
-// CHECK:             %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
-// CHECK:             %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
-// CHECK:             %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
-// CHECK:               %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
-// CHECK:               %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
-// CHECK:               %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
-// CHECK:               %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
-// CHECK:               %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
-// CHECK:               %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
-// CHECK:               %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
-// CHECK:               %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
-// CHECK:               ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
-// CHECK:                 %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
-// CHECK:                 linalg.yield %[[VAL_23]] : f32
-// CHECK:               } -> tensor<64x256x24xf32>
-// CHECK:               %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
-// CHECK:               scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
-// CHECK:             }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+    %consumer = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     %loop = transform.structured.match ops{["scf.for"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
+//      CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(
+// CHECK-SAME:     %[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
+//      CHECK:   %[[VAL_3:.*]] = arith.constant 0 : index
+//      CHECK:   %[[VAL_4:.*]] = arith.constant 64 : index
+//      CHECK:   %[[VAL_5:.*]] = arith.constant 256 : index
+//      CHECK:   %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
+//      CHECK:   %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
+//      CHECK:     %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+//      CHECK:     %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+//      CHECK:     %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+//      CHECK:     %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
+//      CHECK:     %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+//      CHECK:     %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
+//      CHECK:     %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+//      CHECK:     %[[VAL_19:.*]] = linalg.generic
+// CHECK-SAME:         ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
+//      CHECK:     ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+//      CHECK:       %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+//      CHECK:       linalg.yield %[[VAL_23]] : f32
+//      CHECK:     } -> tensor<64x256x24xf32>
+//      CHECK:     %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+//      CHECK:     scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
+//      CHECK:   }
 
 // -----
 
@@ -878,12 +877,12 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
     %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
     %generic:2 = linalg.generic {
         indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
-	iterator_types = ["parallel", "reduction"]}
-	ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+      	iterator_types = ["parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
       ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
         %0 = arith.mulf %b0, %b1 : f32
-	%1 = arith.addf %b0, %b2 : f32
-	linalg.yield %0, %1 : f32, f32
+	      %1 = arith.addf %b0, %b2 : f32
+	      linalg.yield %0, %1 : f32, f32
     } -> (tensor<?xf32>, tensor<?xf32>)
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
@@ -901,6 +900,19 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
   } -> tensor<?xf32>
   return %result : tensor<?xf32>
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %producer, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
 // CHECK-LABEL: func @multi_slice_fusion1(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
 //       CHECK:   %[[C0:.+]] = arith.constant 0
@@ -916,23 +928,9 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
 //       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
 //       CHECK:   return %[[RESULT]]#2
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
-      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
 
 // -----
 
-// Check that when the given operand tiles are inconsistent, tiling fails.
-
 func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -944,20 +942,20 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
     %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
     %generic0 = linalg.generic {
         indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
-	iterator_types = ["parallel", "reduction"]}
-	ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+      	iterator_types = ["parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
       ^bb0(%b0 : f32, %b1 : f32):
         %0 = arith.mulf %b0, %b1 : f32
-	linalg.yield %0 : f32
+	      linalg.yield %0 : f32
     } -> tensor<?xf32>
     %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
     %generic1 = linalg.generic {
         indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
-	iterator_types = ["parallel", "reduction"]}
-	ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+	      iterator_types = ["parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
       ^bb0(%b0 : f32, %b1 : f32):
-	%0 = arith.addf %b0, %b1 : f32
-	linalg.yield %0: f32
+	      %0 = arith.addf %b0, %b1 : f32
+       	linalg.yield %0: f32
     } -> tensor<?xf32>
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
@@ -975,6 +973,19 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
   } -> tensor<?xf32>
   return %result : tensor<?xf32>
 }
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %producer1, %producer2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
 // CHECK-LABEL: func @multi_slice_fusion2(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
 //       CHECK:   %[[C0:.+]] = arith.constant 0
@@ -991,19 +1002,6 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
 //       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
 //       CHECK:   return %[[RESULT]]#2
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
-      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-
 // -----
 
 func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
@@ -1060,11 +1058,11 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
-      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -1124,7 +1122,6 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<
 	      linalg.yield %0: f32
     } -> tensor<?x?xf32>
     scf.forall.in_parallel {
-      // expected-error @below {{failed to fuse consumer of slice}}
       tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
           : tensor<?x?xf32> into tensor<?x?xf32>
       tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
@@ -1132,6 +1129,7 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<
     }
   }
   %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  // expected-error @below {{failed to fuse consumer of slice}}
   %result = linalg.generic {
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
@@ -1146,11 +1144,11 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
-      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 326fec3ee5cf0..51dac0e866254 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -172,7 +172,71 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 
 /// Apply fusing of consumer transformation to all payload ops and store both
 /// the original consumer operation as well as the fused consumer operation.
-static LogicalResult applyFuseConsumer(
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+                  Operation *consumer,
+                  MutableArrayRef<LoopLikeOpInterface> loops,
+                  TransformResults &transformResults) {
+  SmallVector<Operation *> fusedConsumerOps;
+
+  rewriter.setInsertionPoint(consumer);
+
+  FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+      scf::tileAndFuseConsumer(rewriter, consumer, loops);
+
+  if (failed(fuseConsumerResults))
+    return consumer->emitOpError("failed to fuse consumer of slice");
+
+  // Report back the relevant handles to the transform op.
+  for (OpOperand *tiledAndFusedConsumerOperand :
+       fuseConsumerResults->tiledAndFusedConsumerOperands) {
+    fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
+  }
+
+  transformResults.set(transformOp->getOpResult(0), fusedConsumerOps);
+  for (auto [index, loop] : llvm::enumerate(loops)) {
+    transformResults.set(transformOp->getOpResult(index + 1), {loop});
+  }
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  Operation *consumer = *state.getPayloadOps(getConsumer()).begin();
+
+  SmallVector<LoopLikeOpInterface> loops;
+  // Since the matcher works inside-out, we need to iterate the loops in reverse.
+  for (auto loop : llvm::reverse(getLoops())) {
+    auto loopLikeOp =
+        dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(loop).begin());
+    if (!loopLikeOp) {
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    loops.push_back(loopLikeOp);
+  }
+  LogicalResult result = applyFuseConsumer(rewriter, getOperation(), consumer,
+                                           loops, transformResults);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseConsumerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getConsumerMutable(), effects);
+  consumesHandle(getLoopsMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerUsingSliceOp
+//===----------------------------------------------------------------------===//
+
+/// Apply fusing of consumer transformation to all payload ops and store both
+/// the original consumer operation as well as the fused consumer operation.
+static LogicalResult applyFuseConsumerUsingSlices(
     RewriterBase &rewriter, Operation *transformOp,
     ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
     uint32_t numConsumerToFuse, TransformResults &transformResults) {
@@ -204,10 +268,9 @@ static LogicalResult applyFuseConsumer(
   return success();
 }
 
-DiagnosedSilenceableFailure
-transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
-                                     TransformResults &transformResults,
-                                     TransformState &state) {
+DiagnosedSilenceableFailure transform::TestFuseConsumerUsingSliceOp::apply(
+    TransformRewriter &rewriter, TransformResults &transformResults,
+    TransformState &state) {
   SmallVector<Operation *> slices;
   for (auto op : getTargets()) {
     auto sliceOp = *state.getPayloadOps(op).begin();
@@ -224,13 +287,13 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
     loops.push_back(loopLikeOp);
   }
   LogicalResult result =
-      applyFuseConsumer(rewriter, getOperation(), slices, loops,
-                        getNumConsumerToFuse(), transformResults);
+      applyFuseConsumerUsingSlices(rewriter, getOperation(), slices, loops,
+                                   getNumConsumerToFuse(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
 }
 
-void transform::TestFuseConsumerOp::getEffects(
+void transform::TestFuseConsumerUsingSliceOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTargetsMutable(), effects);
   consumesHandle(getLoopsMutable(), effects);
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 694c4229eef62..bfefad02418ac 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,7 +49,7 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
   }];
 }
 
-def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+def TestFuseConsumerUsingSliceOp : Op<Transform_Dialect, "test.fuse_consumer_using_slice",
        [AttrSizedOperandSegments,
         DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -73,6 +73,28 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
   }];
 }
 
+def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the consumer of the operation pointed to by the target handle
+    using the options provided as attributes.
+  }];
+
+  let arguments = (ins 
+      TransformHandleTypeInterface:$consumer,
+      Variadic<TransformHandleTypeInterface>:$loops);
+  let results = (outs TransformHandleTypeInterface:$fused_consumer,
+      Variadic<TransformHandleTypeInterface>:$result_loops);
+
+  let assemblyFormat = [{
+    $consumer `into` `(` $loops `)`
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+
 def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

>From 7e3749038a66585b06087a0eb5c2da221d75eeeb Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 12 Nov 2025 13:17:50 -0800
Subject: [PATCH 2/4] Fix warning (leading to build errors when warnings are
 treated as error)

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 7e715ee189740..03ce5555f56ff 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -2478,9 +2478,10 @@ getProducingInsertSliceLikeOp(OpResult result,
   // tiling and retrieve the `tensor.insert_slice` operation used to construct
   // the result.
   while (loops.size() != 1) {
-    if (result.getOwner() != loops.front())
+    LoopLikeOpInterface loop = loops.front();
+    if (result.getOwner() != loop)
       return std::nullopt;
-    auto forOp = dyn_cast<scf::ForOp>(loops.front());
+    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
     if (!forOp)
       return std::nullopt;
     auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
@@ -2491,9 +2492,10 @@ getProducingInsertSliceLikeOp(OpResult result,
     result = innerForResult;
     loops = loops.drop_front();
   }
-  if (result.getOwner() != loops.front())
+  LoopLikeOpInterface loop = loops.front();
+  if (result.getOwner() != loop)
     return std::nullopt;
-  auto forOp = dyn_cast<scf::ForOp>(loops.front());
+  auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
   if (!forOp)
     return std::nullopt;
   auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());

>From bf8c1de8b53c5421e81de919ee89e693c0119fca Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 12 Nov 2025 13:19:02 -0800
Subject: [PATCH 3/4] Fix linter error.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../TilingInterface/TestTilingInterfaceTransformOps.cpp        | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 51dac0e866254..194c052eb4682 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -207,7 +207,8 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
   Operation *consumer = *state.getPayloadOps(getConsumer()).begin();
 
   SmallVector<LoopLikeOpInterface> loops;
-  // Since the matcher works inside-out, we need to iterate the loops in reverse.
+  // Since the matcher works inside-out, we need to iterate the loops in
+  // reverse.
   for (auto loop : llvm::reverse(getLoops())) {
     auto loopLikeOp =
         dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(loop).begin());

>From 8d4a2d148c8c96b1bf03e786b82edfaeb0b4dfe1 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 20 Nov 2025 11:51:35 -0800
Subject: [PATCH 4/4] Address comments.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../SCF/Transforms/TileUsingInterface.cpp     | 53 ++++++++-----------
 .../TestTilingInterfaceTransformOps.cpp       |  7 +--
 .../TestTilingInterfaceTransformOps.td        | 17 ++++--
 3 files changed, 38 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 03ce5555f56ff..009c2c3537411 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -2427,21 +2427,17 @@ mlir::scf::tileAndFuseConsumerOfSlices(
 
   // Get the consumer of scf.for for the result yielded by
   // tensor.insert_slice/parallel_insert_slice.
-  SmallVector<OpOperand *> consumerOpOperands;
-  Operation *consumerOp;
-  {
-    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
-        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
-    if (failed(maybeConsumerOpOperand)) {
-      return rewriter.notifyMatchFailure(candidateSlices.front(),
-                                         "could not fetch consumer to fuse");
-    }
-    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
-    consumerOp = consumerOpOperands.front()->getOwner();
+  FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
+      getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+  if (failed(maybeConsumerOpOperands)) {
+    return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                       "could not fetch consumer to fuse");
   }
+  Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
 
-  return tileAndFuseConsumerOfSlicesImpl(
-      rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
+  return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp,
+                                         maybeConsumerOpOperands.value(),
+                                         candidateSlices, loops);
 }
 
 /// For a given `result` of a `forallOp` return the
@@ -2455,21 +2451,19 @@ getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
   SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
   // If the number of combining ops is not 1, then this is unexpected. Return
   // nullopt.
-  if (combiningOps.size() != 1) {
+  if (combiningOps.size() != 1)
     return std::nullopt;
-  }
   return combiningOps[0];
 }
 
 /// For a given result of the loop nest that is a tiled loop nest, return the
 /// insert slice-like op that is used for consumer fusion
-std::optional<Operation *>
+static std::optional<Operation *>
 getProducingInsertSliceLikeOp(OpResult result,
                               ArrayRef<LoopLikeOpInterface> loops) {
   assert(!loops.empty() && "Expected loops to be not empty");
-  LoopLikeOpInterface outermostLoop = loops.front();
-
-  if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {
+  LoopLikeOpInterface outerMostLoop = loops.front();
+  if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
     assert(loops.size() == 1 &&
            "expected only a single loop when tiling using scf.forall");
     return getProducingParallelInsertSlice(forallOp, result);
@@ -2485,7 +2479,7 @@ getProducingInsertSliceLikeOp(OpResult result,
     if (!forOp)
       return std::nullopt;
     auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-    OpResult innerForResult =
+    auto innerForResult =
         dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
     if (!innerForResult)
       return std::nullopt;
@@ -2507,27 +2501,26 @@ getProducingInsertSliceLikeOp(OpResult result,
 }
 
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
+mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
                                MutableArrayRef<LoopLikeOpInterface> loops) {
-  // Only handle users that implement the `TilingInterface`.
-  if (!isa<TilingInterface>(user)) {
+  if (!isa<TilingInterface>(consumer)) {
     return rewriter.notifyMatchFailure(
-        user, "unhandled user that does not implement TilingInterface");
+        consumer, "unhandled consumer that does not implement TilingInterface");
   }
 
   // Return if `loops` is empty, return an error for now. Caller is expected
   // to handle this case.
   if (loops.empty()) {
     return rewriter.notifyMatchFailure(
-        user, "cannot call tile and fuse consumer with an empty loop nest");
+        consumer, "cannot call tile and fuse consumer with an empty loop nest");
   }
 
   LoopLikeOpInterface outermostLoop = loops.front();
 
-  // Collect the operands of the user that come from the outermost loop of the
-  // loop nest.
+  // Collect the operands of the consumer that come from the outermost loop of
+  // the loop nest.
   SmallVector<OpOperand *> consumerFusableOperands;
-  for (OpOperand &opOperand : user->getOpOperands()) {
+  for (OpOperand &opOperand : consumer->getOpOperands()) {
     if (opOperand.get().getDefiningOp() == outermostLoop) {
       consumerFusableOperands.push_back(&opOperand);
     }
@@ -2549,13 +2542,13 @@ mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
         getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
     if (!slice) {
       return rewriter.notifyMatchFailure(
-          user,
+          consumer,
           "couldnt find producing insert-slice like operation for operand");
     }
     candidateSlices.push_back(slice.value());
   }
   return tileAndFuseConsumerOfSlicesImpl(
-      rewriter, user, consumerFusableOperands, candidateSlices, loops);
+      rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 194c052eb4682..74bdaaa3d7c57 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -170,20 +170,18 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 // TestFuseConsumerOp
 //===----------------------------------------------------------------------===//
 
-/// Apply fusing of consumer transformation to all payload ops and store both
-/// the original consumer operation as well as the fused consumer operation.
+/// Fuse the consumer and store both the original consumer operation as well as
+/// the fused consumer operation.
 static LogicalResult
 applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
                   Operation *consumer,
                   MutableArrayRef<LoopLikeOpInterface> loops,
                   TransformResults &transformResults) {
   SmallVector<Operation *> fusedConsumerOps;
-
   rewriter.setInsertionPoint(consumer);
 
   FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
       scf::tileAndFuseConsumer(rewriter, consumer, loops);
-
   if (failed(fuseConsumerResults))
     return consumer->emitOpError("failed to fuse consumer of slice");
 
@@ -192,7 +190,6 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
        fuseConsumerResults->tiledAndFusedConsumerOperands) {
     fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
   }
-
   transformResults.set(transformOp->getOpResult(0), fusedConsumerOps);
   for (auto [index, loop] : llvm::enumerate(loops)) {
     transformResults.set(transformOp->getOpResult(index + 1), {loop});
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index bfefad02418ac..29669bd0930ed 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -55,8 +55,13 @@ def TestFuseConsumerUsingSliceOp : Op<Transform_Dialect, "test.fuse_consumer_usi
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
         ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Fuses the consumer of the operation pointed to by the target handle
-    using the options provided as attributes.
+    For the `insert_slice`-like operations (that are typically generated through tiling),
+    within the loop nests passed in as `loops` (that are typically generated through tiling),
+    find the consumer that these slices map to (have to be the same consumer) and fuse
+    the consumer into the loop.
+
+    Returns a handle to the original consumer operation and the consumer operation after
+    fusion.
   }];
 
   let arguments = (ins 
@@ -78,8 +83,12 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
         ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Fuses the consumer of the operation pointed to by the target handle
-    using the options provided as attributes.
+    For the `consumer` that uses the result of the outer-most loop of a loop nest passed in 
+    as `loops` (that are typically generated through tiling), fuse the consumer into the
+    loop.
+
+    Returns a handle to the consumer operation after fusion and the loops that might be
+    modified.
   }];
 
   let arguments = (ins 

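Note on the operand-collection step in `scf::tileAndFuseConsumer` (the loop
over `consumer->getOpOperands()` in PATCH 4/4): it keeps every operand whose
defining op is the outermost loop, which is why a consumer that reads several
results of the same loop nest (as in the `multi_slice_fusion*` tests) needs no
special handling. Below is a minimal sketch of that idea outside MLIR. It is
not the code from the PR: `ToyOp` and `collectFusableOperands` are
hypothetical stand-ins that exist only for this illustration, and the sketch
returns operand indices, whereas the patch collects `OpOperand *`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation; this is not an MLIR class.
// operandProducers[i] is the op that defines operand i, or nullptr when the
// operand has no defining op (for example a block argument).
struct ToyOp {
  std::string name;
  std::vector<const ToyOp *> operandProducers;
};

// Return the indices of the consumer operands that are produced by
// `outermostLoop`, in operand order. The result is empty when no operand
// comes from the loop.
std::vector<std::size_t> collectFusableOperands(const ToyOp &consumer,
                                                const ToyOp *outermostLoop) {
  std::vector<std::size_t> fusable;
  for (std::size_t i = 0; i < consumer.operandProducers.size(); ++i) {
    if (consumer.operandProducers[i] == outermostLoop)
      fusable.push_back(i);
  }
  return fusable;
}
```

For a toy consumer whose operands 0 and 2 are defined by the loop, the helper
returns the indices 0 and 2. In the patch, each collected operand is then
mapped to its producing insert-slice-like op before the consumer is tiled.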