[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 12 16:08:00 PST 2024


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/77874

>From be18c255933990923da9d679c01091a3501f1a20 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 4 Jan 2024 18:50:09 -0800
Subject: [PATCH] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling
 using SCF to unify tiling with `scf.for` and `scf.forall`.

Using `LoopLikeOpInterface` as the basis for the implementation
unifies all the tiling logic for both `scf.for` and `scf.forall`. The
only difference is the actual loop generation.
Instead of many entry points for each loop type, the loop type is now
passed as part of the options passed to the tiling method.

This is a breaking change. The API changes are:

1) The `scf::tileUsingSCFForOp` is renamed to `scf::tileUsingSCF`.
2) The `scf::tileUsingSCFForallOp` is removed. The same
   functionality is obtained by using `scf::tileUsingSCF` and setting
   the loop type in `scf::SCFTilingOptions` passed into this method to
   `scf::SCFTilingOptions::LoopType::ForallOp` (using the
   `setLoopType` method).
3) The `scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp` is
   renamed to `scf::tileConsumerAndFuseProducersUsingSCF`. The
   `controlFn` in `scf::SCFTileAndFuseOptions` allows implementing
   any fusion strategy; the default callback implements greedy fusion.
4) The `scf::SCFTilingResult` and `scf::SCFTileAndFuseResult` now use
   `SmallVector<LoopLikeOpInterface>`.
5) To make `scf::ForallOp` implement the parts of
   `LoopLikeOpInterface` needed, the `getOutputBlockArguments()`
   method is replaced with `getRegionIterArgs()`.
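
As a rough illustration of item 2, the sketch below models selecting the
loop construct through an options struct instead of through separate
entry points. It is standalone C++ and does not use MLIR:
`TilingOptionsSketch` and `selectedLoopConstruct` are hypothetical
stand-ins, and only the `LoopType` enum, its `ForOp` default, and the
chaining `setLoopType` setter mirror what this patch adds to
`scf::SCFTilingOptions`.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for scf::SCFTilingOptions. Only the loop-type
// selection described in the commit message is modeled here.
struct TilingOptionsSketch {
  enum class LoopType { ForOp, ForallOp };
  LoopType loopType = LoopType::ForOp; // same default as in the patch

  // Returns *this so that setters can be chained.
  TilingOptionsSketch &setLoopType(LoopType type) {
    loopType = type;
    return *this;
  }
};

// Hypothetical stand-in for the single tiling entry point. It only
// reports which loop construct the options select; no IR is built.
inline std::string selectedLoopConstruct(const TilingOptionsSketch &options) {
  switch (options.loopType) {
  case TilingOptionsSketch::LoopType::ForOp:
    return "scf.for";
  case TilingOptionsSketch::LoopType::ForallOp:
    return "scf.forall";
  }
  assert(false && "unhandled loop type");
  return "";
}
```

With real MLIR types, a caller that previously used
`scf::tileUsingSCFForallOp` would presumably set
`scf::SCFTilingOptions::LoopType::ForallOp` on its options and call
`scf::tileUsingSCF` instead.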

These changes bring the tiling and fusion capabilities using
`scf.forall` on par with what was already supported using `scf.for`.
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  19 +-
 .../SCF/Transforms/TileUsingInterface.h       |  41 +-
 .../TransformOps/LinalgTransformOps.cpp       |  10 +-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp |   4 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  50 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 813 +++++++++---------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |   8 +-
 mlir/lib/Interfaces/LoopLikeInterface.cpp     |  36 +-
 .../Linalg/generalize-tensor-unpack-tile.mlir |   8 +-
 mlir/test/Dialect/Linalg/tile-conv.mlir       |   2 +-
 mlir/test/Dialect/Linalg/tile-tensors.mlir    |   2 +-
 ...-op-hoist-pad-build-packing-loop-nest.mlir |   4 +-
 .../Linalg/transform-op-hoist-pad.mlir        |   2 +
 .../transform-op-peel-and-vectorize.mlir      |   4 +-
 mlir/test/Dialect/Tensor/tiling.mlir          |  24 +-
 .../tile-and-fuse-using-scfforall.mlir        | 176 ++++
 .../tile-pad-using-interface.mlir             |   5 +-
 .../TilingInterface/tile-using-interface.mlir |  18 +-
 .../TilingInterface/tile-using-scfforall.mlir | 150 +++-
 .../TestTilingInterfaceTransformOps.cpp       | 136 ++-
 .../TestTilingInterfaceTransformOps.td        |  26 +-
 21 files changed, 1010 insertions(+), 528 deletions(-)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8d65d3dd820baf..819d88df0ae57f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -135,9 +135,9 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-       ["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
-        "getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
-        "getLoopResults", "promoteIfSingleIteration",
+       ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
+        "getSingleLowerBound", "getSingleStep", "getSingleUpperBound",
+        "getYieldedValuesMutable", "getLoopResults", "promoteIfSingleIteration",
         "replaceWithAdditionalYields"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
@@ -259,10 +259,6 @@ def ForOp : SCF_Op<"for",
 
     Value getInductionVar() { return getBody()->getArgument(0); }
 
-    Block::BlockArgListType getRegionIterArgs() {
-      return getBody()->getArguments().drop_front(getNumInductionVars());
-    }
-
     /// Return the `index`-th region iteration argument.
     BlockArgument getRegionIterArg(unsigned index) {
       assert(index < getNumRegionIterArgs() &&
@@ -304,8 +300,9 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["promoteIfSingleIteration", "getSingleInductionVar",
-          "getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
+          ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
+           "getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
+           "promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -585,10 +582,6 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
-    ArrayRef<BlockArgument> getOutputBlockArguments() {
-      return getBody()->getArguments().drop_front(getRank());
-    }
-
     ::mlir::ValueRange getInductionVars() {
       return getBody()->getArguments().take_front(getRank());
     }
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 5d2d78e6e6165b..965ef9e203be28 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 #include <deque>
@@ -52,6 +53,14 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify which loop construct to use for tile and fuse.
+  enum class LoopType { ForOp, ForallOp };
+  LoopType loopType = LoopType::ForOp;
+  SCFTilingOptions &setLoopType(LoopType type) {
+    loopType = type;
+    return *this;
+  }
+
   /// Specify mapping of loops to devices. This is only respected when the loop
   /// constructs support such a mapping (like `scf.forall`). Will be ignored
   /// when using loop constructs that dont support such a mapping (like
@@ -71,7 +80,7 @@ struct SCFTilingResult {
   /// of the last op.
   SmallVector<Operation *> tiledOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<Operation *> loops;
+  SmallVector<LoopLikeOpInterface> loops;
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
@@ -79,15 +88,9 @@ struct SCFTilingResult {
 
 /// Method to tile an op that implements the `TilingInterface` using
 /// `scf.for` for iterating over the tiles.
-FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
-                                             TilingInterface op,
-                                             const SCFTilingOptions &options);
-
-/// Method to tile an op that implements the `TilingInterface` using
-/// `scf.forall`.
-FailureOr<SCFTilingResult>
-tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
-                     const SCFTilingOptions &options);
+FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
+                                        TilingInterface op,
+                                        const SCFTilingOptions &options);
 
 /// Options used to control tile + fuse.
 struct SCFTileAndFuseOptions {
@@ -135,7 +138,7 @@ struct SCFFuseProducerOfSliceResult {
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                            tensor::ExtractSliceOp candidateSliceOp,
-                           MutableArrayRef<scf::ForOp> loops);
+                           MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Reconstruct the fused producer from within the tiled-and-fused code. Based
 /// on the slice of the producer computed in place it is possible that within
@@ -187,10 +190,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 /// where `%0` had other uses as well. If not reconstructed from within the loop
 /// body, uses of `%0` could not be replaced, making it still live and the
 /// fusion immaterial.
-void yieldReplacementForFusedProducer(
+LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<scf::ForOp> loops);
+    MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
@@ -201,7 +204,7 @@ struct SCFTileAndFuseResult {
   /// generated operation.
   llvm::SetVector<Operation *> tiledAndFusedOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<Operation *> loops;
+  SmallVector<LoopLikeOpInterface> loops;
   /// The replacement values to use for the tiled and fused operations.
   llvm::DenseMap<Value, Value> replacements;
 };
@@ -232,9 +235,9 @@ struct SCFTileAndFuseResult {
 /// }
 /// ```
 FailureOr<SCFTileAndFuseResult>
-tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
-    RewriterBase &rewriter, TilingInterface consumer,
-    const SCFTileAndFuseOptions &options);
+tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
+                                     TilingInterface consumer,
+                                     const SCFTileAndFuseOptions &options);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
@@ -249,8 +252,8 @@ struct SCFReductionTilingResult {
   Operation *mergeOp;
   /// Initial op
   Operation *initialOp;
-  /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<scf::ForOp> loops;
+  /// The loop operations that iterate over the tiles.
+  SmallVector<LoopLikeOpInterface> loops;
 };
 
 /// Method to tile a reduction and generate a parallel op within a serial loop.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 97d2b4a3be5c56..5efa2b3eb7476f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -485,8 +485,8 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
       [&](TilingInterface tilingInterfaceOp)
           -> FailureOr<scf::SCFTileAndFuseResult> {
-        return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
-            rewriter, tilingInterfaceOp, tileAndFuseOptions);
+        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
+                                                    tileAndFuseOptions);
       });
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
@@ -584,7 +584,7 @@ static Operation *replaceForAllWithNewSignature(
   Operation *firstYieldOp = yieldingOps.front();
   rewriter.setInsertionPoint(firstYieldOp);
   Value src = tileAndFuseResult.tiledValues[0];
-  Value dst = newforallOp.getOutputBlockArguments().back();
+  Value dst = newforallOp.getRegionIterArgs().back();
   SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
   rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
                                                  dst, offsets, sizes, strides);
@@ -2063,7 +2063,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
   });
   SmallVector<int64_t> emptyTileSizes;
   rewriter.setInsertionPoint(target);
-  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
+  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
       rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
   if (failed(maybeTilingResult))
     return emitDefaultDefiniteFailure(target);
@@ -2647,7 +2647,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
 
     tilingOptions.setInterchange(getInterchange());
     FailureOr<scf::SCFTilingResult> maybeTilingResult =
-        tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
+        tileUsingSCF(rewriter, tilingInterface, tilingOptions);
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7f3ab1f1a24b2f..339d06cdeaf60a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -358,7 +358,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
 
   // 3. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
   Operation *tiledOp = nullptr;
   SmallVector<Value> tiledValues;
   {
@@ -695,7 +695,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   // 4. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   SmallVector<Value> tilingResults;
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
   {
     // 4.a. RAII guard, inserting within forallOp, before terminator.
     OpBuilder::InsertionGuard g(b);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5570c2ec688c8a..7dbc8f250b25eb 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -527,6 +527,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
 
 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
 
+Block::BlockArgListType ForOp::getRegionIterArgs() {
+  return getBody()->getArguments().drop_front(getNumInductionVars());
+}
+
 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
   return getInitArgsMutable();
 }
@@ -622,6 +626,47 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
   return success();
 }
 
+Block::BlockArgListType ForallOp::getRegionIterArgs() {
+  return getBody()->getArguments().drop_front(getRank());
+}
+
+MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
+  return getOutputsMutable();
+}
+
+FailureOr<LoopLikeOpInterface>
+ForallOp::replaceWithAdditionalYields(RewriterBase &rewriter,
+                                      ValueRange newInitOperands,
+                                      bool replaceInitOperandUsesInLoop,
+                                      const NewYieldValuesFn &newYieldValueFn) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+  auto inits = llvm::to_vector(getOutputs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  auto newLoop = rewriter.create<scf::ForallOp>(
+      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
+      inits, getMapping(), [](OpBuilder &, Location, ValueRange) {});
+
+  // Move the region of the current block to the newly created op.
+  Block *newLoopBody = newLoop.getBody();
+  rewriter.mergeBlocks(
+      getBody(), newLoopBody,
+      newLoopBody->getArguments().take_front(getBody()->getNumArguments()));
+
+  // Update the terminator.
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
+    rewriter.setInsertionPointToEnd(terminator.getBody());
+    newYieldValueFn(
+        rewriter, getLoc(),
+        newLoopBody->getArguments().take_back(newInitOperands.size()));
+  }
+  rewriter.replaceOp(getOperation(),
+                     newLoop->getResults().take_front(getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 /// Promotes the loop body of a scf::ForallOp to its containing block.
 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   OpBuilder::InsertionGuard g(rewriter);
@@ -1630,9 +1675,8 @@ struct FoldTensorCastOfOutputIntoForallOp
     // mapped to the tensor.cast old-typed results of the output bbArgs. The
     // destination have to be updated to point to the output bbArgs directly.
     auto terminator = newForallOp.getTerminator();
-    for (auto [yieldingOp, outputBlockArg] :
-         llvm::zip(terminator.getYieldingOps(),
-                   newForallOp.getOutputBlockArguments())) {
+    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
+             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
       auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
       insertSliceOp.getDestMutable().assign(outputBlockArg);
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 38e0625d7ce093..838edb7a00e0fc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -55,32 +55,8 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
   return filledVector;
 }
 
-/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
-template <typename SrcOpTy>
-static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
-  return llvm::to_vector(
-      llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
-}
-template <typename SrcOpTy>
-static SmallVector<Operation *>
-getAsOperations(const SmallVector<SrcOpTy> &ops) {
-  return getAsOperations(ArrayRef<SrcOpTy>(ops));
-}
-
-/// Convert a list of `Operation *` to a list of `DstOpTy.
-template <typename DstOpTy>
-static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
-  return llvm::to_vector(
-      llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
-}
-template <typename DstOpTy>
-static SmallVector<DstOpTy>
-castToTypedOperations(const SmallVector<Operation *> &ops) {
-  return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
-}
-
 //===----------------------------------------------------------------------===//
-// tileUsingSCFForOp implementation.
+// tileUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
 // Check if `stride` evenly divides the trip count `size - offset`.
@@ -135,66 +111,201 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
   return clonedOp;
 }
 
-/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// Type of the call back function used to generate the body of the tiled
+/// loops. The loop generation methods use this callback to generate
+/// the body of the inner-most tile loop. The call back is provided with
+/// - `rewriter`: with insertion point set to the end of the inner most loop
+///    body
+/// - `ivs`: the induction variables for the surrounding loops.
+/// - `regionIterArgs`: the basic block arguments of the inner most loop that
+///   correspond to the init/result of the loop.
+/// - `resultOffsets/Sizes`: The call back returns the result of tiling the
+///   operation, but in `resultOffsets` and `resultSizes` the callback function
+///   is expected to populate the offsets/sizes to use while inserting the
+///   result back into the corresponding `regionIterArgs`. These values
+///   are used by the loop generation methods to create the appropriate yield
+///   values from the inner most loop.
+/// The call back returns the `TilingResult` obtained from tiling the operation.
+using TileLoopBodyFn = std::function<FailureOr<TilingResult>(
+    RewriterBase &rewriter, Location loc, ValueRange ivs,
+    ValueRange regionIterArgs,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
+
+/// Generate the tile-loop nest using `scf.for` operation.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
-/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
-///   the tile processed within the inner most loop.
-/// Note that this methods adds `scf.yield` operation for all but the innermost
-/// loop. These yield the value returned by the immediately inner loop. The
-/// caller is expected to add the scf.yield operation for the innermost loop.
-static SmallVector<scf::ForOp> generateTileLoopNest(
-    OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
-    SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
-  if (loopRanges.empty())
-    return {};
+/// - `destinationTensors` are the init values to use for the outer most loop.
+/// - `tileLoopBodyFn` is called to generated the loop body of the inner most
+///    loop.
+/// - `loops` is an in-out parameter into which the generated loops are
+///    populated.
+/// The `TilingResult` returned by calling `tileLoopBodyFn` is returned back
+/// to the caller.
+static FailureOr<TilingResult> generateLoopNestUsingForOp(
+    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
+    TileLoopBodyFn tileLoopBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+  assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
-  OpBuilder::InsertionGuard guard(builder);
-  SmallVector<scf::ForOp> loops;
-  offsets.resize(loopRanges.size());
-  sizes.resize(loopRanges.size());
-
-  for (auto loopRange : llvm::enumerate(loopRanges)) {
-    Value offset =
-        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
-    Value size =
-        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
-    Value tileSize = getValueOrCreateConstantIndexOp(
-        builder, loc, tileSizes[loopRange.index()]);
+  OpBuilder::InsertionGuard guard(rewriter);
+  SmallVector<Value> ivs;
+
+  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
     // No loops if tile size is zero. Set offset and size to the loop
     // offset and size.
-    if (matchPattern(tileSize, m_Zero())) {
-      offsets[loopRange.index()] = offset;
-      sizes[loopRange.index()] = size;
+    if (isConstantIntValue(tileSize, 0))
       continue;
-    }
 
-    auto loop = builder.create<scf::ForOp>(
-        loc, offset, size, tileSize, destinationTensors,
-        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
-            ValueRange /*iterArgs*/) {
-          sizes[loopRange.index()] =
-              getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv,
-                                 getAsOpFoldResult(tileSize));
-        });
-    offsets[loopRange.index()] = loop.getInductionVar();
+    Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+    Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+    auto loop =
+        rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
+                                    [](OpBuilder &bodyBuilder, Location bodyLoc,
+                                       Value iv, ValueRange /*iterArgs*/) {});
     loops.push_back(loop);
-    builder.setInsertionPointToEnd(loop.getBody());
+    ivs.push_back(loop.getInductionVar());
+    rewriter.setInsertionPointToEnd(loop.getBody());
     destinationTensors = loop.getRegionIterArgs();
   }
 
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  FailureOr<TilingResult> tilingResult = tileLoopBodyFn(
+      rewriter, loc, ivs, destinationTensors, resultOffsets, resultSizes);
+  if (failed(tilingResult)) {
+    return rewriter.notifyMatchFailure(
+        loc, "failed to generate inner tile loop body");
+  }
+  if (loops.empty())
+    return tilingResult;
+
+  // 6. Yield all the results of the tiled operation.
+  SmallVector<Value> yieldedValues;
+  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
+       llvm::zip_equal(tilingResult->tiledValues, destinationTensors,
+                       resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
+                                           rewriter.getIndexAttr(1));
+    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledValue, destinationTensor, resultOffset, resultSize,
+        resultStride);
+    yieldedValues.push_back(insertSlice);
+  }
+  rewriter.create<scf::YieldOp>(loc, yieldedValues);
+
   // Add the scf.yield operations for all the outer loops.
-  if (!loops.empty()) {
-    for (auto [outerLoop, innerLoop] :
-         llvm::zip_equal(MutableArrayRef(loops).drop_back(),
-                         MutableArrayRef(loops).drop_front())) {
-      builder.setInsertionPointToEnd(outerLoop.getBody());
-      builder.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop.getResults());
-    }
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip_equal(MutableArrayRef(loops).drop_back(),
+                       MutableArrayRef(loops).drop_front())) {
+    rewriter.setInsertionPointToEnd(
+        cast<scf::ForOp>(outerLoop.getOperation()).getBody());
+    rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
   }
-  return loops;
+  return tilingResult;
+}
+
+/// Generate the tile-loop nest using `scf.forall` operation.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
+/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
+/// - `destinationTensors` are the init values to use for the outer most loop.
+/// - `mappingVector` is the mapping attributes to use for loop construction.
+///   Can be empty.
+/// - `tileLoopBodyFn` is called to generated the loop body of the inner most
+///    loop.
+/// - `loops` is an in-out parameter into which the generated loops are
+///    populated.
+/// The `TilingResult` returned by calling `tileLoopBodyFn` is returned back
+/// to the caller.
+static FailureOr<TilingResult> generateLoopNestUsingForallOp(
+    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
+    ValueRange destinationTensors, TileLoopBodyFn tiledBodyFn,
+    SmallVector<LoopLikeOpInterface> &loops) {
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  assert(!loopRanges.empty() && "unexpected empty loop ranges");
+  assert(loopRanges.size() == tileSizes.size() &&
+         "expected as many tile sizes as loop ranges");
+  OpBuilder::InsertionGuard guard(rewriter);
+  SmallVector<OpFoldResult> offsets(loopRanges.size()),
+      sizes(loopRanges.size());
+
+  for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+  assert(!lbs.empty() && "Expected at least one loop range");
+
+  std::optional<ArrayAttr> mappingAttr;
+  if (!mappingVector.empty())
+    mappingAttr = rewriter.getArrayAttr(mappingVector);
+
+  auto forallOp = rewriter.create<scf::ForallOp>(
+      loc, lbs, ubs, steps, destinationTensors, mappingAttr);
+  loops.push_back(forallOp);
+
+  rewriter.setInsertionPoint(forallOp.getTerminator());
+  destinationTensors = forallOp.getRegionOutArgs();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  FailureOr<TilingResult> tilingResult =
+      tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
+                  destinationTensors, resultOffsets, resultSizes);
+
+  if (failed(tilingResult))
+    return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
+
+  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
+  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
+       llvm::zip_equal(tilingResult->tiledValues, destinationTensors,
+                       resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
+                                           rewriter.getIndexAttr(1));
+
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        loc, tiledValue, destinationTensor, resultOffset, resultSize,
+        resultStride);
+  }
+  return tilingResult;
+}
+
+/// Generate the tile-loop nest using the loop construct specifed in `options`.
+/// - `options`: Tiling options specified.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
+/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
+/// - `destinationTensors` are the init values to use for the outer most loop.
+/// - `tileLoopBodyFn` is called to generated the loop body of the inner most
+///    loop.
+/// - `loops` is an in-out parameter into which the generated loops are
+///    populated.
+/// The `TilingResult` returned by calling `tileLoopBodyFn` is returned back
+/// to the caller.
+static FailureOr<TilingResult>
+generateLoopNest(RewriterBase &rewriter, Location loc,
+                 const scf::SCFTilingOptions &options,
+                 ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
+                 ValueRange destinationTensors, TileLoopBodyFn tiledBodyFn,
+                 SmallVector<LoopLikeOpInterface> &loops) {
+  // If the tile sizes are all zero, no loops are generated. Just call the
+  // callback function to handle untiled case.
+  if (llvm::all_of(tileSizes, isZeroIndex)) {
+    SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+    return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
+                       resultOffsets, resultSizes);
+  }
+  if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
+    return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
+                                      destinationTensors, tiledBodyFn, loops);
+  }
+  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
+    return generateLoopNestUsingForallOp(
+        rewriter, loc, loopRanges, tileSizes, options.mappingVector,
+        destinationTensors, tiledBodyFn, loops);
+  }
+  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
 /// Method to add new init values to a loop nest. Updates `loops` in-place with
@@ -202,26 +313,26 @@ static SmallVector<scf::ForOp> generateTileLoopNest(
 /// The outer-loops are updated to yield the new result values of the inner
 /// loop. For the innermost loop, the call back `getNewYields` is invoked to get
 /// the additional values to yield form the innermost loop.
-static void addInitOperandsToLoopNest(
-    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loops,
-    ValueRange newInitValues,
-    llvm::function_ref<SmallVector<Value>(RewriterBase &rewriter, Value iv,
-                                          ValueRange newRegionIterArgs)>
-        getNewYieldValsFn) {
+static LogicalResult addInitOperandsToLoopNest(
+    RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
+    ValueRange newInitValues, const NewYieldValuesFn &getNewYieldValsFn) {
   SmallVector<scf::ForOp> newLoops;
   if (loops.empty())
-    return;
+    return success();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(loops.front());
-  for (auto &loop : loops) {
+  for (auto &loop : loops.drop_back()) {
     rewriter.setInsertionPoint(loop);
 
+    // If loops.size() > 1 we assume that scf.for is used for the loops.
+    auto forLoop = cast<scf::ForOp>(loop.getOperation());
+
     // Create a new loop with the new init values for this loop.
-    SmallVector<Value> newInits = llvm::to_vector(loop.getInitArgs());
+    SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
     newInits.append(newInitValues.begin(), newInitValues.end());
     auto newLoop = rewriter.create<scf::ForOp>(
-        loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
-        loop.getStep(), newInits,
+        forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
+        forLoop.getStep(), newInits,
         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
 
     // Merge the body of the new loop with the body of the old loops.
@@ -230,48 +341,47 @@ static void addInitOperandsToLoopNest(
     auto newRegionIterArgs = newLoop.getRegionIterArgs();
     sourceBlockArgs.append(
         newRegionIterArgs.begin(),
-        std::next(newRegionIterArgs.begin(), loop.getNumResults()));
-    rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(), sourceBlockArgs);
-    rewriter.replaceOp(loop,
-                       newLoop.getResults().take_front(loop.getNumResults()));
+        std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
+    rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
+    rewriter.replaceOp(
+        forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
     loop = newLoop;
     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
   }
 
   // Update the loop body of the innermost loop to get new yield values.
-  scf::ForOp innerMostLoop = loops.back();
-  auto innerMostYieldOp =
-      cast<scf::YieldOp>(innerMostLoop.getBody()->getTerminator());
-  rewriter.setInsertionPoint(innerMostYieldOp);
-  SmallVector<Value> newYieldVals =
-      getNewYieldValsFn(rewriter, innerMostLoop.getInductionVar(),
-                        innerMostLoop.getRegionIterArgs());
-  SmallVector<Value> newYieldOperands =
-      llvm::to_vector(innerMostYieldOp->getOperands());
-  newYieldOperands.append(newYieldVals);
-  rewriter.replaceOpWithNewOp<scf::YieldOp>(innerMostYieldOp, newYieldOperands);
+  LoopLikeOpInterface innerMostLoop = loops.back();
+  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
+      innerMostLoop.replaceWithAdditionalYields(rewriter, newInitValues, false,
+                                                getNewYieldValsFn);
+  if (failed(newInnerMostLoop))
+    return innerMostLoop.emitOpError("failed to return additional yields");
+  loops.back() = newInnerMostLoop.value();
 
   // Make all other loops except the innermost loops yield the values returned
   // by the inner loop.
   for (auto [outerLoop, innerLoop] :
        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+    // Again assume that all the outer loops are scf.for operations.
+    auto outerForLoop = cast<scf::ForOp>(outerLoop);
     auto outerLoopYield =
-        cast<scf::YieldOp>(outerLoop.getBody()->getTerminator());
+        cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
     SmallVector<Value> newYields =
         llvm::to_vector(outerLoopYield.getOperands());
     ValueRange additionalYields =
-        innerLoop.getResults().take_back(newInitValues.size());
+        innerLoop->getResults().take_back(newInitValues.size());
     newYields.append(additionalYields.begin(), additionalYields.end());
     rewriter.setInsertionPoint(outerLoopYield);
     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
   }
+  return success();
 }
 
 /// Implementation of tiling transformation of `op` that implements the
 /// `TilingInterface` using `scf.for` to iterate over the tiles.
 FailureOr<scf::SCFTilingResult>
-mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
-                             const scf::SCFTilingOptions &options) {
+mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
+                        const scf::SCFTilingOptions &options) {
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfter(op);
 
@@ -288,145 +398,129 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  SmallVector<OpFoldResult> tileSizeVector =
+  SmallVector<OpFoldResult> tileSizes =
       options.tileSizeComputationFunction(rewriter, op);
-  if (tileSizeVector.size() < iterationDomain.size()) {
+  if (tileSizes.size() < iterationDomain.size()) {
     auto zero = rewriter.getIndexAttr(0);
-    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+    tileSizes.append(numLoops - tileSizes.size(), zero);
   }
 
-  // 3. Find the destination tensors to use for the operation.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "unable to create destination tensors");
+  // 3. If there is an interchange specified, permute the iteration domain and
+  // the tile sizes.
+  SmallVector<int64_t> interchangeVector;
+  if (!options.interchangeVector.empty()) {
+    interchangeVector = fillInterchangeVector(options.interchangeVector,
+                                              iterationDomain.size());
   }
-
-  SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> forLoops;
-  {
-    // If there is an interchange specified, permute the iteration domain and
-    // the tile sizes.
-    SmallVector<int64_t> interchangeVector;
-    if (!options.interchangeVector.empty()) {
-      interchangeVector = fillInterchangeVector(options.interchangeVector,
-                                                iterationDomain.size());
+  if (!interchangeVector.empty()) {
+    if (!isPermutationVector(interchangeVector)) {
+      return rewriter.notifyMatchFailure(
+          op, "invalid interchange vector, not a permutation of the entire "
+              "iteration space");
     }
-    if (!interchangeVector.empty()) {
-      if (!isPermutationVector(interchangeVector)) {
-        return rewriter.notifyMatchFailure(
-            op, "invalid intechange vector, not a permutation of the entire "
-                "iteration space");
-      }
 
-      applyPermutationToVector(iterationDomain, interchangeVector);
-      applyPermutationToVector(tileSizeVector, interchangeVector);
+    applyPermutationToVector(iterationDomain, interchangeVector);
+    applyPermutationToVector(tileSizes, interchangeVector);
+  }
+
+  // 4. Define the lambda function used later to generate the body of the
+  // innermost tiled loop.
+  auto innerTileLoopBodyFn =
+      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+          ValueRange regionIterArgs,
+          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+      -> FailureOr<TilingResult> {
+    // 4a. Compute the `offsets` and `sizes` to use for tiling.
+    SmallVector<OpFoldResult> offsets, sizes;
+    {
+      int materializedLoopNum = 0;
+      for (auto [tileSize, loopRange] :
+           llvm::zip_equal(tileSizes, iterationDomain)) {
+        if (isConstantIntValue(tileSize, 0)) {
+          offsets.push_back(loopRange.offset);
+          sizes.push_back(loopRange.size);
+          continue;
+        }
+        Value iv = ivs[materializedLoopNum++];
+        offsets.push_back(iv);
+        sizes.push_back(
+            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+      }
     }
 
-    // 4. Materialize an empty loop nest that iterates over the tiles. These
-    // loops for now do not return any values even if the original operation has
-    // results.
-    forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
-                                    tileSizeVector, offsets, sizes,
-                                    destinationTensors);
-
+    // 4b. If interchange was provided, apply inverse of the interchange
+    //     to get back the offsets/sizes in the order to be specified.
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
       applyPermutationToVector(offsets, inversePermutation);
       applyPermutationToVector(sizes, inversePermutation);
     }
-  }
 
-  LLVM_DEBUG({
-    if (!forLoops.empty()) {
-      llvm::dbgs() << "LoopNest shell :\n";
-      forLoops.front().dump();
-      llvm::dbgs() << "\n";
+    // 5. Generate the tiled implementation within the inner most loop.
+
+    // 5a. Clone the operation within the loop body.
+    auto clonedOp = cast<TilingInterface>(
+        cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
+
+    // 5b. Early return cloned op if tiling is not happening. We cannot return
+    // the original op because it could lead to
+    // `rewriter.replaceOp(op, op->getResults())` and users would get a crash.
+    if (llvm::all_of(tileSizes, isZeroIndex))
+      return TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+
+    // 5c. Tile the cloned operation.
+    FailureOr<TilingResult> tilingResult =
+        clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+    if (failed(tilingResult))
+      return op.emitOpError("failed to tile operation");
+
+    // 5d. Delete the cloned operation.
+    rewriter.eraseOp(clonedOp);
+
+    // 5e. Compute the offsets at which the result values are to be inserted
+    //     back into its destinations.
+    for (auto [index, tiledValue] :
+         llvm::enumerate(tilingResult->tiledValues)) {
+      SmallVector<OpFoldResult> resultOffset, resultSize;
+      if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                          resultOffset, resultSize))) {
+        return rewriter.notifyMatchFailure(
+            op, "failed to get slice of result produced");
+      }
+      resultOffsets.emplace_back(std::move(resultOffset));
+      resultSizes.emplace_back(std::move(resultSize));
     }
-  });
-
-  // 5. Generate the tiled implementation within the inner most loop.
-  SmallVector<Value> clonedOpDestination = destinationTensors;
-  if (!forLoops.empty()) {
-    rewriter.setInsertionPointToEnd(forLoops.back().getBody());
-    clonedOpDestination =
-        llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
-                            [](BlockArgument b) -> Value { return b; });
-  }
 
-  // 5a. Clone the operation within the loop body.
-  auto clonedOp = cast<TilingInterface>(
-      cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
-
-  // 5b. Early return cloned op if tiling is not happening. We can not return
-  // the original op because it could lead to
-  // `rewriter.replaceOp(op, op->getResults())` and user would get crash.
-  if (llvm::all_of(tileSizeVector, isZeroIndex)) {
-    return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{},
-                                clonedOp->getResults()};
-  }
+    return tilingResult;
+  };
 
-  // 5c. Tile the cloned operation.
-  FailureOr<TilingResult> tiledImplementation =
-      clonedOp.getTiledImplementation(rewriter, offsets, sizes);
-  if (failed(tiledImplementation)) {
-    return rewriter.notifyMatchFailure(op, "failed to tile operation");
+  // 6. Find the destination tensors to use for the operation.
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+                                             destinationTensors))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unable to create destination tensors");
   }
 
-  // 5d. Delete the cloned operation.
-  rewriter.eraseOp(clonedOp);
+  // 7. Generate the tiled loop nest using the callback defined above.
+  SmallVector<LoopLikeOpInterface> loops;
+  FailureOr<TilingResult> tilingResult = generateLoopNest(
+      rewriter, op.getLoc(), options, iterationDomain, tileSizes,
+      destinationTensors, innerTileLoopBodyFn, loops);
+  if (failed(tilingResult))
+    return op.emitOpError("failed to generate tiling loops");
 
   // If loops are empty, the tiled op is used as the replacement for the untiled
   // op.
-  if (forLoops.empty()) {
-    return scf::SCFTilingResult{tiledImplementation->tiledOps,
-                                getAsOperations(forLoops),
-                                tiledImplementation->tiledValues};
-  }
-
-  if (op->getNumResults() == 0) {
-    // The innermost loop does not have a `scf.yield` yet. There is nothing to
-    // return, so generate an empty `scf.yield` operation.
-    rewriter.setInsertionPointToEnd(forLoops.back().getBody());
-    rewriter.create<scf::YieldOp>(op->getLoc());
-    return scf::SCFTilingResult{
-        tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
-  }
-
-  // 6. Yield all the results of the tiled operation.
-  int64_t numResults = op->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
-      resultSizesList(numResults);
-  SmallVector<Value> yieldedValues;
-  for (auto [index, tiledValue] :
-       llvm::enumerate(tiledImplementation->tiledValues)) {
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
-                                        resultOffsets, resultSizes))) {
-      return rewriter.notifyMatchFailure(
-          op, "failed to get slice of result produced");
-    }
-    SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
-                                            rewriter.getIndexAttr(1));
-    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
-        op->getLoc(), tiledValue, clonedOpDestination[index], resultOffsets,
-        resultSizes, resultStrides);
-    yieldedValues.push_back(insertSlice);
+  if (loops.empty()) {
+    return scf::SCFTilingResult{tilingResult->tiledOps, loops,
+                                tilingResult->tiledValues};
   }
-  rewriter.create<scf::YieldOp>(op->getLoc(), yieldedValues);
 
   SmallVector<Value> replacements = llvm::map_to_vector(
-      forLoops.front().getResults(), [](OpResult r) -> Value { return r; });
-  LLVM_DEBUG({
-    if (!forLoops.empty()) {
-      llvm::dbgs() << "After tiled implementation :\n";
-      forLoops.front().dump();
-      llvm::dbgs() << "\n";
-    }
-  });
-  return scf::SCFTilingResult{tiledImplementation->tiledOps,
-                              getAsOperations(forLoops), replacements};
+      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
+  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -464,50 +558,71 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
                                 "cannot create a tensor of identity value.");
-  // 3. Create the nested loops.
-  SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> loops =
-      generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
-                           sizes, identityTensor.value()->getResults());
-
-  // 4. Generate the tiled implementation within the inner most loop.
-  // 4a. Clone the operation within the loop body.
-  SmallVector<Value> clonedOpDestination =
+
+  // 3. Define the callback to use for generating the inner most tile loop body.
+  auto innerTileLoopBodyFn =
+      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+          ValueRange regionIterArgs,
+          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+      -> FailureOr<TilingResult> {
+    SmallVector<OpFoldResult> offsets, sizes;
+    {
+      int materializedLoopNum = 0;
+      for (auto [tileSize, loopRange] :
+           llvm::zip_equal(tileSizesVector, iterationDomain)) {
+        if (isConstantIntValue(tileSize, 0)) {
+          offsets.push_back(loopRange.offset);
+          sizes.push_back(loopRange.size);
+          continue;
+        }
+        Value iv = ivs[materializedLoopNum++];
+        offsets.push_back(iv);
+        sizes.push_back(
+            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+      }
+    }
+
+    // 4a. Clone the operation.
+    auto clonedOp = cast<PartialReductionOpInterface>(
+        cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+
+    // 4b. Tile the cloned operation.
+    Operation *parallelOp = clonedOp.tileToPartialReduction(
+        b, loc, regionIterArgs, offsets, sizes, reductionDims);
+    // 4c. Delete the cloned operation.
+    b.eraseOp(clonedOp);
+
+    // 4d. Compute the offsets and sizes needed to insert the result of the
+    // tiled value back into destination before yielding the destination.
+    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+    resultOffsets.emplace_back(std::move(outOffsets));
+
+    SmallVector<OpFoldResult> outSizes;
+    for (size_t i = 0; i < offsets.size(); i++) {
+      outSizes.push_back(
+          tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+    }
+    resultSizes.emplace_back(std::move(outSizes));
+    return TilingResult{{parallelOp}, parallelOp->getResults()};
+  };
+
+  // 5. Generate the tiled implementation using the destination tensors.
+  SmallVector<Value> destinationTensors =
       llvm::map_to_vector(identityTensor.value()->getResults(),
                           [](OpResult res) -> Value { return res; });
-  if (!loops.empty()) {
-    b.setInsertionPointToEnd(loops.back().getBody());
-    clonedOpDestination =
-        llvm::map_to_vector(loops.back().getRegionIterArgs(),
-                            [](BlockArgument b) -> Value { return b; });
-  }
-  auto clonedOp = cast<PartialReductionOpInterface>(
-      cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
-
-  // 4b. Tile the cloned operation.
-  Operation *parallelOp = clonedOp.tileToPartialReduction(
-      b, loc, clonedOpDestination, offsets, sizes, reductionDims);
-  // 4c. Delete the cloned operation.
-  b.eraseOp(clonedOp);
-
-  SmallVector<OpFoldResult> outSizes;
-  for (size_t i = 0; i < offsets.size(); i++) {
-    outSizes.push_back(
-        tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
-  }
-  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-  SmallVector<OpFoldResult> outStrides(outOffsets.size(), b.getIndexAttr(1));
-  SmallVector<Value> yieldedVals;
-  auto bbArgs = loops.back().getRegionIterArgs();
-  for (auto [result, bbArg] : llvm::zip(parallelOp->getResults(), bbArgs)) {
-    Value insert = b.create<tensor::InsertSliceOp>(
-        loc, result, bbArg, outOffsets, outSizes, outStrides);
-    yieldedVals.push_back(insert);
-  }
-  b.create<scf::YieldOp>(loc, yieldedVals);
+
+  SmallVector<LoopLikeOpInterface> loops;
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  FailureOr<TilingResult> tilingResult =
+      generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+                       destinationTensors, innerTileLoopBodyFn, loops);
+  if (failed(tilingResult))
+    return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
 
   SmallVector<Value> replacements = llvm::map_to_vector(
-      loops.front().getResults(), [](OpResult r) -> Value { return r; });
+      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
 
   // 5. Apply the merge reduction to combine all the partial values.
   b.setInsertionPointAfter(*loops.begin());
@@ -516,14 +631,14 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 
   SCFReductionTilingResult results;
   results.initialOp = *identityTensor;
-  results.loops = std::move(loops);
-  results.parallelTiledOp = parallelOp;
+  results.loops = loops;
+  results.parallelTiledOp = tilingResult->tiledOps.front();
   results.mergeOp = mergeOp;
   return results;
 }
 
 //===----------------------------------------------------------------------===//
-// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
+// tileConsumerAndFuseProducersUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
 /// Return the untiled producer whose slice is used in a tiled consumer. The
@@ -533,11 +648,11 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 /// no loop traversal needed, the second value of the returned tuple is empty.
 static std::tuple<OpResult, std::optional<OpOperand *>>
 getUntiledProducerFromSliceSource(OpOperand *source,
-                                  ArrayRef<scf::ForOp> loops) {
+                                  ArrayRef<LoopLikeOpInterface> loops) {
   std::optional<OpOperand *> destinationIterArg;
   auto loopIt = loops.rbegin();
   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
-    scf::ForOp loop = *loopIt;
+    auto loop = *loopIt;
     if (iterArg.getOwner()->getParentOp() != loop)
       break;
     source = loop.getTiedLoopInit(iterArg);
@@ -551,9 +666,9 @@ getUntiledProducerFromSliceSource(OpOperand *source,
 /// Implementation of fusing producer of a single slice by computing the
 /// slice of the producer in-place.
 std::optional<scf::SCFFuseProducerOfSliceResult>
-mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
-                                      tensor::ExtractSliceOp candidateSliceOp,
-                                      MutableArrayRef<scf::ForOp> loops) {
+mlir::scf::tileAndFuseProducerOfSlice(
+    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
   auto [fusableProducer, destinationInitArg] =
@@ -666,12 +781,12 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
-void mlir::scf::yieldReplacementForFusedProducer(
+LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<scf::ForOp> loops) {
+    MutableArrayRef<LoopLikeOpInterface> loops) {
   if (loops.empty())
-    return;
+    return success();
 
   OpResult fusableProducer = fusedProducerInfo.origProducer;
   Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
@@ -679,15 +794,15 @@ void mlir::scf::yieldReplacementForFusedProducer(
       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
   if (succeeded(initValue)) {
 
-    auto newYieldValuesFn =
-        [&](RewriterBase &innerRewriter, Value iv,
-            ValueRange newRegionIterArgs) -> SmallVector<Value> {
+    NewYieldValuesFn newYieldValuesFn =
+        [&](OpBuilder &innerRewriter, Location loc,
+            ArrayRef<BlockArgument> newRegionIterArgs) -> SmallVector<Value> {
       OpBuilder::InsertionGuard g(innerRewriter);
       if (auto tiledDestStyleOp =
               tiledAndFusedProducer
                   .getDefiningOp<DestinationStyleOpInterface>()) {
         rewriter.setInsertionPoint(tiledDestStyleOp);
-        BlockArgument newRegionArg = loops.back().getRegionIterArgs().back();
+        BlockArgument newRegionArg = newRegionIterArgs.back();
         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
             sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
             sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
@@ -700,21 +815,22 @@ void mlir::scf::yieldReplacementForFusedProducer(
       rewriter.setInsertionPoint(block->getTerminator());
       Value replacement = rewriter.create<tensor::InsertSliceOp>(
           fusedProducerInfo.origProducer.getLoc(),
-          fusedProducerInfo.tiledAndFusedProducer,
-          loops.back().getRegionIterArgs().back(), sliceOp.getMixedOffsets(),
-          sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+          fusedProducerInfo.tiledAndFusedProducer, newRegionIterArgs.back(),
+          sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+          sliceOp.getMixedStrides());
       return {replacement};
     };
 
-    addInitOperandsToLoopNest(rewriter, loops,
-                              SmallVector<Value>{initValue.value()},
-                              newYieldValuesFn);
+    return addInitOperandsToLoopNest(rewriter, loops,
+                                     SmallVector<Value>{initValue.value()},
+                                     newYieldValuesFn);
   }
+  return success();
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
 FailureOr<scf::SCFTileAndFuseResult>
-mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+mlir::scf::tileConsumerAndFuseProducersUsingSCF(
     RewriterBase &rewriter, TilingInterface consumer,
     const scf::SCFTileAndFuseOptions &options) {
   // This transformation is only valid for ops that return values (i.e. not
@@ -727,24 +843,25 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   // 1. First tile the consumer.
   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
+
   FailureOr<scf::SCFTilingResult> tilingResult =
-      tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
+      tileUsingSCF(rewriter, consumer, options.tilingOptions);
+
   if (failed(tilingResult))
     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
   for (auto *tiledOp : tilingResult->tiledOps)
     tiledAndFusedOps.insert(tiledOp);
-  SmallVector<scf::ForOp> forLoops =
-      castToTypedOperations<scf::ForOp>(tilingResult->loops);
 
   // If there are no loops generated, fusion is immaterial.
-  if (forLoops.empty()) {
+  auto &loops = tilingResult->loops;
+  if (loops.empty()) {
     DenseMap<Value, Value> replacements;
     for (auto [origVal, replacement] :
          llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
       replacements[origVal] = replacement;
     }
-    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
-                                     getAsOperations(forLoops), replacements};
+    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
+                                     replacements};
   }
 
   // To keep track of replacements for now just record the map from the original
@@ -780,7 +897,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     // Find the original producer of the slice.
     auto [fusableProducer, destinationInitArg] =
         getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
-                                          forLoops);
+                                          loops);
     if (!fusableProducer)
       continue;
 
@@ -793,15 +910,19 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
-        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
+        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
     if (!fusedResult)
       continue;
 
     if (yieldReplacement) {
-      yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
-                                       fusedResult.value(), forLoops);
+      if (failed(yieldReplacementForFusedProducer(
+              rewriter, candidateSliceOp, fusedResult.value(), loops))) {
+        return rewriter.notifyMatchFailure(
+            fusableProducer.getOwner(), "failed to yield replacement for this "
+                                        "operation from within the tiled loop");
+      }
       origValToResultNumber[fusableProducer] =
-          forLoops.front().getNumResults() - 1;
+          loops.front()->getNumResults() - 1;
     }
 
     if (Operation *tiledAndFusedOp =
@@ -814,123 +935,11 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
 
   DenseMap<Value, Value> replacements;
   for (auto [origVal, resultNumber] : origValToResultNumber) {
-    replacements[origVal] = forLoops.front()->getResult(resultNumber);
-  }
-
-  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
-                                   getAsOperations(forLoops), replacements};
-}
-
-//===----------------------------------------------------------------------===//
-// tileUsingSCFForAllOp implementation.
-//===----------------------------------------------------------------------===//
-
-FailureOr<scf::SCFTilingResult>
-mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
-                                const scf::SCFTilingOptions &options) {
-  Location loc = op->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-
-  // 1. Get the range of loops that are represented by the operation.
-  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
-  if (loopRanges.empty())
-    return op->emitOpError("expected non-empty loop ranges");
-  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
-  if (llvm::any_of(loopRanges, hasStrideOne))
-    return op->emitOpError("only stride-1 supported atm");
-
-  // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
-  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
-  SmallVector<OpFoldResult> tileSizeVector =
-      options.tileSizeComputationFunction(rewriter, op);
-  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
-
-  // 3. Build the offsets, sizes and steps for the tile and distributed loops.
-  SmallVector<OpFoldResult> lbs, ubs, steps;
-  for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
-    if (isConstantIntValue(tileSize, 0))
-      continue;
-    lbs.push_back(loopRange.offset);
-    ubs.push_back(loopRange.size);
-    steps.push_back(tileSize);
-  }
-
-  // 4. Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
-    return op->emitOpError("failed to get destination tensors");
-
-  // 5. Build the device mapping attribute.
-  std::optional<ArrayAttr> mappingAttr;
-  if (!options.mappingVector.empty()) {
-    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
-  }
-
-  // 6. Create the ForallOp. We don't use the lambda body-builder
-  // version because we require the use of RewriterBase in the body, so we
-  // manually move the insertion point to the body below.
-  auto forallOp =
-      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
-
-  // 7. Get the tile offset and sizes.
-  rewriter.setInsertionPoint(forallOp.getTerminator());
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  ValueRange ivs = forallOp.getInductionVars();
-  {
-    int materializedLoopNum = 0;
-    for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
-      if (isConstantIntValue(tileSize, 0)) {
-        tiledOffsets.push_back(loopRange.offset);
-        tiledSizes.push_back(loopRange.size);
-        continue;
-      }
-      Value iv = ivs[materializedLoopNum++];
-      tiledOffsets.push_back(iv);
-      tiledSizes.push_back(
-          getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
-    }
-  }
-
-  // 8. Tile the operation. Clone the operation to allow fix up of destination
-  // operands.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
-  Operation *clonedOp =
-      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
-  FailureOr<TilingResult> tilingResult =
-      cast<TilingInterface>(clonedOp).getTiledImplementation(
-          rewriter, tiledOffsets, tiledSizes);
-  if (failed(tilingResult))
-    return clonedOp->emitError("failed to tile op: ");
-  rewriter.eraseOp(clonedOp);
-
-  // 9. Parallel insert back into the result tensor.
-  for (auto [index, tiledValue, destBBArg] :
-       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
-    // 9.a. Partial subset information is inserted just before the terminator.
-    rewriter.setInsertionPoint(forallOp.getTerminator());
-
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
-                                        tiledSizes, resultOffsets,
-                                        resultSizes))) {
-      return op->emitOpError("output offsets couldn't be calculated");
-    }
-
-    SmallVector<OpFoldResult> strides(resultSizes.size(),
-                                      rewriter.getIndexAttr(1));
-    // 9.b. Parallel insertions are inserted at the end of the combining
-    // terminator.
-    rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
-    rewriter.create<tensor::ParallelInsertSliceOp>(
-        loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
+    replacements[origVal] = loops.front()->getResult(resultNumber);
   }
 
-  // 10. Return the tiling result.
-  return scf::SCFTilingResult{
-      tilingResult->tiledOps,
-      {forallOp.getOperation()},
-      llvm::map_to_vector(forallOp.getResults(),
-                          [](auto val) -> Value { return val; })};
+  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
+                                   replacements};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a2043c647d49a3..cdd85ddeb93add 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -933,11 +933,11 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
   fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
 
   // Map shared outs.
-  fusedMapping.map(target.getOutputBlockArguments(),
-                   fusedLoop.getOutputBlockArguments().slice(0, numTargetOuts));
+  fusedMapping.map(target.getRegionIterArgs(),
+                   fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
   fusedMapping.map(
-      source.getOutputBlockArguments(),
-      fusedLoop.getOutputBlockArguments().slice(numTargetOuts, numSourceOuts));
+      source.getRegionIterArgs(),
+      fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
 
   // Append everything except the terminator into the fused operation.
   rewriter.setInsertionPointToStart(fusedLoop.getBody());
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index be1316b95688bf..1e0e87b64e8113 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -63,8 +63,9 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
     return op->emitOpError("different number of inits and region iter_args: ")
            << loopLikeOp.getInits().size()
            << " != " << loopLikeOp.getRegionIterArgs().size();
-  if (loopLikeOp.getRegionIterArgs().size() !=
-      loopLikeOp.getYieldedValues().size())
+  if (!loopLikeOp.getYieldedValues().empty() &&
+      loopLikeOp.getRegionIterArgs().size() !=
+          loopLikeOp.getYieldedValues().size())
     return op->emitOpError(
                "different number of region iter_args and yielded values: ")
            << loopLikeOp.getRegionIterArgs().size()
@@ -78,21 +79,22 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
 
   // Verify types of inits/iter_args/yielded values/loop results.
   int64_t i = 0;
-  for (const auto it :
-       llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
-                       loopLikeOp.getYieldedValues())) {
-    if (std::get<0>(it).getType() != std::get<1>(it).getType())
-      return op->emitOpError(std::to_string(i))
-             << "-th init and " << i
-             << "-th region iter_arg have different type: "
-             << std::get<0>(it).getType()
-             << " != " << std::get<1>(it).getType();
-    if (std::get<1>(it).getType() != std::get<2>(it).getType())
-      return op->emitOpError(std::to_string(i))
-             << "-th region iter_arg and " << i
-             << "-th yielded value have different type: "
-             << std::get<1>(it).getType()
-             << " != " << std::get<2>(it).getType();
+  auto yieldedValues = loopLikeOp.getYieldedValues();
+  for (const auto [index, init, regionIterArg] :
+       llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
+    if (init.getType() != regionIterArg.getType())
+      return op->emitOpError(std::to_string(index))
+             << "-th init and " << index
+             << "-th region iter_arg have different type: " << init.getType()
+             << " != " << regionIterArg.getType();
+    if (!yieldedValues.empty()) {
+      if (regionIterArg.getType() != yieldedValues[index].getType())
+        return op->emitOpError(std::to_string(index))
+               << "-th region iter_arg and " << index
+               << "-th yielded value have different type: "
+               << regionIterArg.getType()
+               << " != " << yieldedValues[index].getType();
+    }
     ++i;
   }
   i = 0;
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
index 0e27c6a783e6f1..f0d4b790520e03 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
@@ -46,11 +46,11 @@ func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
 // CHECK:         %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] =
-// CHECK:           %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
 // CHECK:           %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] =
-// CHECK:             %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
-// CHECK:             %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
-// CHECK:             %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-DAG:         %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK-DAG:         %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
+// CHECK-DAG:         %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
+// CHECK-DAG:         %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
 // CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
 // CHECK-SAME:          [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:             %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}}
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 4a940f12662e6c..c42bdbe982c4fa 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -30,8 +30,8 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[H:.*]] = memref.dim %[[ARG2]], %[[C0]]
 //   CHECK-DAG:   %[[W:.*]] = memref.dim %[[ARG2]], %[[C1]]
 //       CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[H]] step %[[C2]]
-//       CHECK:     %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[H]]]
 //       CHECK:     scf.for %[[J:.*]] = %[[C0]] to %[[W]] step %[[C3]]
+//   CHECK-DAG:     %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[H]]]
 //   CHECK-DAG:       %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[W]]]
 //   CHECK-DAG:       %[[T6:.*]] = affine.apply #[[MAP2]](%[[T4]])[%[[KH]]]
 //   CHECK-DAG:       %[[T7:.*]] = affine.apply #[[MAP2]](%[[T5]])[%[[KW]]]
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index e8e63302286400..cdef71ded8b2ca 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -137,9 +137,9 @@ func.func @fold_extract_slice(
   //      CHECK:   %[[E:.*]] = tensor.extract_slice %[[ARG0]][3, 4] [%[[DIM]], 42] [1, 1] : tensor<?x128xf32> to tensor<?x42xf32>
 
   //      CHECK:    scf.for %[[IV0:[0-9a-zA-Z]*]] =
-  //      CHECK:      %[[SIZE0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM]]
   //      CHECK:      scf.for %[[IV1:[0-9a-zA-Z]*]] =
 
+  //      CHECK:      %[[SIZE0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM]]
   // Fold the existing extract slice op into the one created by the tiling.
   //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[E]]
   // CHECK-SAME:                                          %[[IV0]], %[[IV1]]
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir
index 6bec1cbd65be68..1be5bf098c334c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir
@@ -177,8 +177,10 @@ module attributes {transform.with_named_sequence} {
     %pad = transform.get_producer_of_operand %matmul_padded[2]
       : (!transform.any_op) -> !transform.any_op
 
+    transform.apply_licm to %loops_l1#1 : !transform.any_op
+
     transform.structured.hoist_pad.build_packing_loop_nest %pad above %loops_l1#1
        : (!transform.any_op, !transform.any_op) -> !transform.any_op
-       transform.yield
+    transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
index b66db855035bcd..37cb9b2376fb43 100644
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -202,6 +202,8 @@ module attributes {transform.with_named_sequence} {
     %pad = transform.get_producer_of_operand %matmul_padded[2]
       : (!transform.any_op) -> !transform.op<"tensor.pad">
 
+    transform.apply_licm to %loops_l1#1 : !transform.any_op
+
     transform.structured.hoist_pad %pad by 1 loops
        : (!transform.op<"tensor.pad">) -> !transform.any_op
        transform.yield
diff --git a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize.mlir
index 762648050fdfdf..d54cace31efb99 100644
--- a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize.mlir
@@ -52,9 +52,9 @@ func.func @matmul(%A: tensor<1024x512xf32>,
 // CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write {{.*}} } : vector<8x[16]xi1> -> tensor<8x?xf32>
 // CHECK:           scf.yield %inserted_slice : tensor<1024x2000xf32>
 // CHECK:         }
-// CHECK:         scf.yield %7 : tensor<1024x2000xf32>
+// CHECK:         scf.yield {{.*}} : tensor<1024x2000xf32>
 // CHECK:       }
-// CHECK:       scf.yield %5 : tensor<1024x2000xf32>
+// CHECK:       scf.yield {{.*}} : tensor<1024x2000xf32>
 // CHECK-NEXT:    }
 
   %res = linalg.matmul ins(%A, %B: tensor<1024x512xf32>, tensor<512x2000xf32>)
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index bb42f84afc50f9..1afbd3d0504f74 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -127,8 +127,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     else
 //       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, {{.*}}] [7, {{.*}}] [1, 1]
 //       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[5, {{.*}}]
-//       CHECK:     %[[CAST_SWAP_RESULT:.*]] = tensor.cast %[[SWAP_RESULT]] : tensor<?x?xf32> to tensor<15x?xf32>
-//       CHECK:     tensor.insert_slice %[[CAST_SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [15, {{.*}}] [1, 1]
+//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [15, {{.*}}] [1, 1]
 //       CHECK:   return %[[RESULT]]
 
 func.func @static_pad_tensor_0_3(%input_tensor: tensor<7x9xf32>,
@@ -158,15 +157,12 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:   %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C15]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
 //       CHECK:     %[[R2:.*]] = scf.if
 //       CHECK:       %[[GEN:.*]] = tensor.generate
-//       CHECK:       %[[cast_0:.*]] = tensor.cast %[[GEN]] : tensor<14x3xf32> to tensor<?x3xf32>
-//       CHECK:       scf.yield %[[cast_0]] : tensor<?x3xf32>
+//       CHECK:       scf.yield %[[GEN]] : tensor<14x3xf32>
 //       CHECK:     else
 //       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
 //       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[0, 0] high[7, %{{.*}}]
-//       CHECK:       %[[cast_1:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<?x3xf32>
-//       CHECK:       scf.yield %[[cast_1]] : tensor<?x3xf32>
-//       CHECK:     %[[cast:.*]] = tensor.cast %[[R2]] : tensor<?x3xf32> to tensor<14x3xf32>
-//       CHECK:     %[[R3:.*]] = tensor.insert_slice %[[cast]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
+//       CHECK:       scf.yield %[[PAD]] : tensor<14x3xf32>
+//       CHECK:     %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
 //       CHECK:     scf.yield %[[R3]] : tensor<14x15xf32>
 //       CHECK:   return %[[RESULT]] : tensor<14x15xf32>
 
@@ -312,8 +308,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:       %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x8x2xf32>
 // CHECK-DAG:       %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x8x2xf32>
 // CHECK:           %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x8x2xf32>) {
-// CHECK-DAG:         %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
 // CHECK:             %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x8x2xf32>) {
+// CHECK-DAG:           %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
 // CHECK-DAG:           %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
 // CHECK-DAG:           %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])
 // CHECK-DAG:           %[[IN_I_SZ:.*]] = affine.min #[[MAP3]]
@@ -364,11 +360,11 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:       %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x?x?xf32>
 // CHECK-DAG:       %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x?x?xf32>
 // CHECK:           %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x?x?xf32>) {
-// CHECK:             %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
 // CHECK:             %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x?x?xf32>) {
-// CHECK:               %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
-// CHECK:               %[[IN_D0:.*]] = tensor.dim %[[IN]], %[[C0]]
-// CHECK:               %[[IN_D1:.*]] = tensor.dim %[[IN]], %[[C1]]
+// CHECK-DAG:           %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
+// CHECK-DAG:           %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
+// CHECK-DAG:           %[[IN_D0:.*]] = tensor.dim %[[IN]], %[[C0]]
+// CHECK-DAG:           %[[IN_D1:.*]] = tensor.dim %[[IN]], %[[C1]]
 // CHECK:               %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])[%[[TILE_0]]]
 // CHECK:               %[[IN_I_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_I_SZ]], %[[I]])[%[[TILE_0]], %[[IN_D0]]]
 // CHECK:               %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]]
@@ -550,8 +546,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:     %[[DIM_0:.+]] = tensor.dim %[[OUT]], %[[C0]]
 // CHECK-DAG:     %[[DIM_1:.+]] = tensor.dim %[[OUT]], %[[C1]]
 // CHECK:         %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[DIM_0]] step %[[C2]]
-// CHECK-DAG:       %[[OUT_K_SZ:.+]] = affine.min #[[MAP0]](%[[K]])[%[[DIM_0]]]
 // CHECK:           %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[DIM_1]] step %[[C4]]
+// CHECK-DAG:         %[[OUT_K_SZ:.+]] = affine.min #[[MAP0]](%[[K]])[%[[DIM_0]]]
 // CHECK-DAG:         %[[OUT_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])[%[[DIM_1]]]
 // CHECK-DAG:         %[[IN_K:.+]] = affine.apply #[[MAP2]](%[[K]])
 // CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir
new file mode 100644
index 00000000000000..0bd2546e082b5a
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %gemm : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_using_forall %matmul [10, 20]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @gemm_fill_fusion(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty
+//      CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME:       shared_outs(%[[ITERARG0:.+]] = %[[INIT]])
+//  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+//  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV0]], %[[IV1]]]
+//      CHECK:     %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:         outs(%[[FILL_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[ITERARG0]][%[[IV0]], %[[IV1]]]
+//      CHECK:     }
+
+// -----
+
+func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %generic = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  return %generic : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_using_forall %generic [10, 20]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @gemm_generic_fusion(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty
+//      CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME:       shared_outs(%[[ITERARG0:.+]] = %[[INIT]])
+//  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+//  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+//      CHECK:     %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:         outs(%[[FILL_TILE]] :
+//  CHECK-DAG:     %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
+//  CHECK-DAG:     %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV0]], %[[IV1]]]
+//      CHECK:     %[[GENERIC_TILE:.+]] = linalg.generic
+// CHECK-SAME:         ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
+// CHECK-SAME:         outs(%[[OUTS_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GENERIC_TILE]] into %[[ITERARG0]][%[[IV0]], %[[IV1]]]
+//      CHECK:     }
+
+// -----
+
+func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<30xf32>
+  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
+  %2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %8 = arith.maximumf %arg2, %arg1 : f32
+      linalg.yield %8 : f32
+    } -> tensor<30xf32>
+  %3 = tensor.empty() : tensor<30x3xf32>
+  %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
+  %5:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
+      %8 = arith.subf %arg1, %arg2 : f32
+      %9 = math.exp %8 : f32
+      %10 = arith.addf %arg3, %9 : f32
+      linalg.yield %10, %9 : f32, f32
+    } -> (tensor<30xf32>, tensor<30x3xf32>)
+  %6 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+      %8 = arith.divf %arg1, %arg2 : f32
+      linalg.yield %8 : f32
+    } -> tensor<30x3xf32>
+  return %6 : tensor<30x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %generic1, %generic2, %generic3 = transform.split_handle %generics
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_using_forall %generic3 [10]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//       CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>)
+//   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<30xf32>
+//   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32>
+//       CHECK:   %[[RESULT:[a-zA-Z0-9]+]] = scf.forall (%[[IV:[a-zA-Z0-9]+]])
+//  CHECK-SAME:       shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]])
+//   CHECK-DAG:     %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0]
+//   CHECK-DAG:     %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]]
+//       CHECK:     %[[FILL0:.+]] = linalg.fill
+//  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[ARG0_SLICE]] :
+//  CHECK-SAME:         outs(%[[FILL0]] :
+//       CHECK:     %[[FILL1:.+]] = linalg.fill
+//  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
+//       CHECK:     %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+//       CHECK:     %[[GENERIC1:.+]]:2 = linalg.generic
+//  CHECK-SAME:         ins(%[[ARG0_SLICE]], %[[GENERIC0]] :
+//  CHECK-SAME:         outs(%[[FILL1]], %[[INIT1_SLICE]] :
+//       CHECK:     %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//       CHECK:     %[[GENERIC2:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 :
+//  CHECK-SAME:         outs(%[[ITERARG0_SLICE]] :
+//       CHECK:     scf.forall.in_parallel {
+//       CHECK:       tensor.parallel_insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
+//       CHECK:     }
+//       CHECK:   return %[[RESULT]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
index 05b7afdf0d1ca4..ba56206f03d767 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
@@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     else
 //       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
 //       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[{{.*}}, {{.*}}]
-//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][%[[C0]], {{.*}}] [%[[DIM0]], {{.*}}] [1, 1]
+//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [%[[DIM0]], {{.*}}] [1, 1]
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -143,7 +143,6 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-SAME:     %[[IN:.*]]: tensor<7x9xf32>
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
 //   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
 //       CHECK:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
 //       CHECK:     %[[SWAP_RESULT:.*]] = scf.if
@@ -151,7 +150,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     else
 //       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, {{.*}}] [7, {{.*}}] [1, 1]
 //       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[5, {{.*}}]
-//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][%[[C0]], {{.*}}] [%[[C15]], {{.*}}] [1, 1]
+//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [15, {{.*}}] [1, 1]
 //       CHECK:   return %[[RESULT]]
 
 /// Rest of the tests only check that they dont fail.
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 444232e9e1e2e1..607836faafb71d 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -30,10 +30,10 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
-//  CHECK-DAG:     %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
 //  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//  CHECK-DAG:       %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
 //      CHECK:       %[[TS_X:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
 // CHECK-SAME:           [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
@@ -82,13 +82,13 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
 //      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
-//  CHECK-DAG:     %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
 //  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
-//  CHECK-DAG:       %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
 //  CHECK-DAG:       %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
-//      CHECK:         %[[TS_K:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[K]]]
+//  CHECK-DAG:         %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
+//  CHECK-DAG:         %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
+//  CHECK-DAG:         %[[TS_K:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[K]]]
 //  CHECK-DAG:         %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
 // CHECK-SAME:             [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
 //  CHECK-DAG:         %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
@@ -137,11 +137,11 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty()
 //       CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
 //  CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
-//   CHECK-DAG:     %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
 //   CHECK-DAG:     %[[C300:.+]] = arith.constant 300 : index
 //   CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //       CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
 //  CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
+//   CHECK-DAG:       %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
 //   CHECK-DAG:       %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
 //  CHECK-SAME:           [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
 //   CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
@@ -203,14 +203,14 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
 //      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[INIT]])
-//  CHECK-DAG:     %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
 //  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
-//  CHECK-DAG:       %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]]
 //  CHECK-DAG:       %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
 // CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
+//  CHECK-DAG:         %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
+//  CHECK-DAG:         %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]]
 //  CHECK-DAG:         %[[TS_C:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[C]]]
 //  CHECK-DAG:         %[[TS_H:.+]] = affine.apply #[[$MAP3]](%[[TS_P]])[%[[R]]]
 //  CHECK-DAG:         %[[TS_W:.+]] = affine.apply #[[$MAP4]](%[[TS_Q]])[%[[S]]]
@@ -302,14 +302,14 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
-//  CHECK-DAG:     %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]
 //  CHECK-DAG:     %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:     %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
-//  CHECK-DAG:       %[[TS_K:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[K]]]
 //  CHECK-DAG:       %[[C10:.+]] = arith.constant 10 : index
 //      CHECK:       %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
 // CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
+//  CHECK-DAG:         %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]
+//  CHECK-DAG:         %[[TS_K:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[K]]]
 //  CHECK-DAG:         %[[TS_M:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[M]]]
 //  CHECK-DAG:         %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
 // CHECK-SAME:             [%[[IV2]], %[[IV1]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
index db0c1327e2fe02..c5aff744b57ee6 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -2,7 +2,7 @@
 
 func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.matmul 
+  %0 = linalg.matmul
     ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
@@ -12,7 +12,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.tile_using_forall %matmul [10, 20] mapping [#gpu.block<y>, #gpu.block<x>]
+    %a, %b = transform.test.tile_using_forall %matmul [10, 20] mapping = [#gpu.block<y>, #gpu.block<x>]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
@@ -49,6 +49,48 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+    %arg2 : memref<?x?xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.tile_using_forall %matmul [10, 20]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+//      CHECK-LABEL: func.func @simple_matmul_memref(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//      CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = (0, 0) to (%[[M]], %[[N]]) step (10, 20) {
+//  CHECK-DAG:     %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
+//  CHECK-DAG:     %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
+//  CHECK-DAG:     %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
+// CHECK-SAME:         [%[[IV0]], 0] [%[[TS_M]], %[[K]]] [1, 1]
+//  CHECK-DAG:     %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME:         [0, %[[IV1]]] [%[[K]], %[[TS_N]]] [1, 1]
+//  CHECK-DAG:     %[[OUT_TILE:.+]] = memref.subview %[[ARG2]]
+// CHECK-SAME:         [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:     linalg.matmul
+// CHECK-SAME:             ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:             outs(%[[OUT_TILE]] :
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -203,3 +245,107 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:   %[[INDEX1:.+]] = linalg.index 1
 //       CHECK:   %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
 //       CHECK:   arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
+
+// -----
+
+func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.tile_using_forall %matmul [10, 20] interchange = [1, 0] mapping = [#gpu.block<y>, #gpu.block<x>]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+//      CHECK-LABEL: func.func @interchange_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]
+// CHECK-SAME:       (0, 0) to (%[[N]], %[[M]]) step (20, 10)
+// CHECK-SAME:       shared_outs(%[[INIT0:.+]] = %[[ARG2]])
+//  CHECK-DAG:     %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]
+//  CHECK-DAG:     %[[TS_M:.+]] = affine.min #[[$MAP2]](%[[IV1]])[%[[M]]]
+//  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:         [%[[IV1]], 0] [%[[TS_M]], %[[K]]] [1, 1]
+//  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:         [0, %[[IV0]]] [%[[K]], %[[TS_N]]] [1, 1]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT0]]
+// CHECK-SAME:         [%[[IV1]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT0]]
+// CHECK-SAME:           [%[[IV1]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:     } {mapping = [#gpu.block<y>, #gpu.block<x>]}
+//      CHECK:   return %[[OUTER]]
+
+// -----
+
+func.func @check_scalar_operation(%arg0 : tensor<f32>) -> tensor<f32> {
+  %init = tensor.empty() : tensor<f32>
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+      iterator_types = []}      
+      ins(%arg0 : tensor<f32>) outs(%init : tensor<f32>){
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a = transform.test.tile_using_forall %generic []
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @check_scalar_operation
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.generic
+
+// -----
+
+func.func @check_scalar_memref_operation(%arg0 : memref<f32>, %arg1 : memref<f32>){
+  linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+      iterator_types = []}      
+      ins(%arg0 : memref<f32>) outs(%arg1 : memref<f32>){
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a = transform.test.tile_using_forall %generic []
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @check_scalar_memref_operation
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.generic
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index cc450f45649516..d22d73c8ce31b6 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -54,12 +54,10 @@ static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
 /// Apply a tile and fuse transformation to all payload ops and store both the
 /// tiled operation as well as the created tile loops.
 template <typename Range>
-static LogicalResult
-applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
-                      Range &&payloadOps, unsigned numLoops,
-                      ArrayRef<OpFoldResult> tileSizes,
-                      ArrayRef<int64_t> interchange,
-                      transform::TransformResults &transformResults) {
+static LogicalResult applyTileAndFuseToAll(
+    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
+    unsigned numLoops, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<int64_t> interchange, TransformResults &transformResults) {
   SmallVector<Operation *> tiledOps;
   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
 
@@ -97,8 +95,8 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
 
     rewriter.setInsertionPoint(target);
     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
-        scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
-            rewriter, tilingInterfaceOp, tileAndFuseOptions);
+        scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
+                                                  tileAndFuseOptions);
     if (failed(tiledResults))
       return failure();
 
@@ -109,12 +107,11 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
       for (OpResult res : toReplace->getResults())
         if (auto replacement = tiledResults->replacements.lookup(res)) {
           Operation *replacementOp = replacement.getDefiningOp();
-          rewriter.replaceUsesWithIf(
-              res, replacement, [&](mlir::OpOperand &use) {
-                Operation *user = use.getOwner();
-                return dominanceInfo.properlyDominates(replacementOp, user) &&
-                       user->getParentOp() == replacementOp->getParentOp();
-              });
+          rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
+            Operation *user = use.getOwner();
+            return dominanceInfo.properlyDominates(replacementOp, user) &&
+                   user->getParentOp() == replacementOp->getParentOp();
+          });
         }
 
       if (toReplace->use_empty()) {
@@ -138,10 +135,10 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
   return success();
 }
 
-DiagnosedSilenceableFailure transform::TestFuseAndYieldOp::apply(
-    transform::TransformRewriter &rewriter,
-    mlir::transform::TransformResults &transformResults,
-    mlir::transform::TransformState &state) {
+DiagnosedSilenceableFailure
+transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
   SmallVector<int64_t> tileSizes =
       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
   SmallVector<int64_t> tileInterchange =
@@ -169,7 +166,7 @@ static LogicalResult
 applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
                Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
                ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
-               transform::TransformResults &transformResults) {
+               TransformResults &transformResults) {
   SmallVector<Operation *> tiledOps;
   SmallVector<Operation *> loopOps;
 
@@ -186,10 +183,11 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
           });
       tilingOptions.setMapping(mappingAttrs);
     }
+    tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
 
     rewriter.setInsertionPoint(target);
     FailureOr<scf::SCFTilingResult> tiledResults =
-        scf::tileUsingSCFForallOp(rewriter, tilingInterfaceOp, tilingOptions);
+        scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
     if (failed(tiledResults))
       return failure();
 
@@ -209,10 +207,10 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
   return success();
 }
 
-DiagnosedSilenceableFailure transform::TestTileUsingForallOp::apply(
-    transform::TransformRewriter &rewriter,
-    mlir::transform::TransformResults &transformResults,
-    mlir::transform::TransformState &state) {
+DiagnosedSilenceableFailure
+transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
   SmallVector<int64_t> tileSizes =
       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
   SmallVector<int64_t> interchange =
@@ -235,6 +233,96 @@ void transform::TestTileUsingForallOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// TestFuseUsingForallOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operations and the created tile loops.
+template <typename Range>
+static LogicalResult applyTilingToAll(
+    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
+    unsigned numLoops, TransformResults &transformResults,
+    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
+        applyFn) {
+  SmallVector<Operation *> tiledLinalgOps;
+  SmallVector<SmallVector<Operation *>> loopOps(1);
+
+  for (Operation *target : payloadOps) {
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp)
+      return transformOp->emitError("only TilingInterface ops are supported");
+
+    rewriter.setInsertionPoint(target);
+    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+        applyFn(tilingInterfaceOp);
+    if (failed(tiledResults))
+      return failure();
+
+    // Perform the replacement of tiled and fused values.
+    SmallVector<Operation *> opsToReplace{target};
+    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
+    for (Operation *toReplace : opsToReplace) {
+      for (OpResult res : toReplace->getResults())
+        if (auto replacement = tiledResults->replacements.lookup(res))
+          rewriter.replaceAllUsesWith(res, replacement);
+      if (toReplace->use_empty())
+        rewriter.eraseOp(toReplace);
+    }
+
+    // Report back the relevant handles to the transform op.
+    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
+    assert(tiledResults->loops.size() == 1 &&
+           cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
+           "Mismatched number of loops, tile and fuse transform should have "
+           "failed");
+    loopOps[0].push_back({tiledResults->loops[0]});
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
+  if (!loopOps.empty())
+    transformResults.set(transformOp->getOpResult(1), loopOps[0]);
+
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+  SmallVector<int64_t> tileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+  SmallVector<int64_t> tileInterchange =
+      extractFromIntegerArrayAttr<int64_t>(getInterchange());
+
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.interchangeVector = tileInterchange;
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
+  tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+  scf::SCFTileAndFuseOptions tileAndFuseOptions;
+  tileAndFuseOptions.tilingOptions = tilingOptions;
+  LogicalResult result = applyTilingToAll(
+      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
+      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+      [&](TilingInterface tilingInterfaceOp)
+          -> FailureOr<scf::SCFTileAndFuseResult> {
+        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
+                                                    tileAndFuseOptions);
+      });
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseUsingForallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getRootOp(), effects);
+  producesHandle(getTiledOps(), effects);
+  producesHandle(getLoops(), effects);
+  modifiesPayload(effects);
+}
+
 #define GET_OP_CLASSES
 #include "TestTilingInterfaceTransformOps.cpp.inc"
 
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 6e9354198896ab..17965fee4fbd13 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -71,11 +71,33 @@ def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
                       Variadic<TransformHandleTypeInterface>:$loops);
 
   let assemblyFormat = [{
-    $target ($tile_sizes^)? (`interchange` $interchange^)?
-    (`mapping` $mapping^)?
+    $target ($tile_sizes^)? (`interchange` `=` $interchange^)?
+    (`mapping` `=` $mapping^)?
     attr-dict `:` functional-type(operands, results)
   }];
 }
 
+def TestFuseUsingForallOp : Op<Transform_Dialect, "test.fuse_using_forall",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Test operation that tiles the operations pointed to by the target handle
+    and greedily fuses their producers, using the options provided as
+    attributes. This operation uses scf.forall for the loop construct.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$root_op,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+                   DefaultValuedOptionalAttr<I64ArrayAttr, "{}">:$interchange,
+                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
+  let results = (outs TransformHandleTypeInterface:$tiled_ops,
+                      Variadic<TransformHandleTypeInterface>:$loops);
+
+  let assemblyFormat = [{
+    $root_op ($tile_sizes^)? (`interchange` $interchange^)?
+    (`mapping` `=` $mapping^)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
 
 #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS
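
Reader's note (not part of the patch): the CHECK lines for @simple_matmul_memref expect tile sizes computed by affine.min over maps of the form (d0)[s0] -> (10, -d0 + s0). The following is a minimal standalone C++ sketch of that arithmetic, assuming tile sizes of 10 and 20 as in the test. The names Tile, tileSize and enumerateTiles are hypothetical and do not appear in MLIR. The sequential loops only enumerate the iteration points; scf.forall itself makes no ordering guarantee.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical record for one tile of the (M, N) iteration space.
struct Tile {
  int64_t offM, offN;   // offsets, i.e. the induction variables IV0 and IV1
  int64_t sizeM, sizeN; // sizes, i.e. TS_M and TS_N in the CHECK lines
};

// Models affine.min applied to affine_map<(d0)[s0] -> (step, -d0 + s0)>:
// a full tile of `step` elements, or whatever remains at the boundary.
inline int64_t tileSize(int64_t iv, int64_t ub, int64_t step) {
  return std::min(step, ub - iv);
}

// Enumerates the tiles visited by a loop nest of the shape
//   (iv0, iv1) = (0, 0) to (M, N) step (stepM, stepN).
inline std::vector<Tile> enumerateTiles(int64_t M, int64_t N,
                                        int64_t stepM = 10,
                                        int64_t stepN = 20) {
  std::vector<Tile> tiles;
  for (int64_t iv0 = 0; iv0 < M; iv0 += stepM)
    for (int64_t iv1 = 0; iv1 < N; iv1 += stepN)
      tiles.push_back(
          {iv0, iv1, tileSize(iv0, M, stepM), tileSize(iv1, N, stepN)});
  return tiles;
}
```

For example, enumerateTiles(25, 45) yields a 3 x 3 grid of tiles whose last entry starts at offset (20, 40) with size 5 x 5. The boundary tiles shrink so that the tiles cover the iteration space exactly, which is why the subview sizes in the test are TS_M and TS_N and not the constants 10 and 20.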


