[flang-commits] [flang] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

via flang-commits flang-commits at lists.llvm.org
Fri Jan 19 17:19:17 PST 2024


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/77874

>From 391d9a7f7c9045666fbd23f25c34e03d2be7cb5a Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 4 Jan 2024 18:50:09 -0800
Subject: [PATCH 01/12] [mlir][TilingInterface] Use `LoopLikeOpInterface` in
 tiling using SCF to unify tiling with `scf.for` and `scf.forall`.

Using `LoopLikeOpInterface` as the basis for the implementation
unifies all the tiling logic for both `scf.for` and `scf.forall`. The
only difference is the actual loop generation.
Instead of many entry points for each loop type, the loop type is now
passed as part of the options passed to the tiling method.

This is a breaking change with the following changes:

1) The `scf::tileUsingSCFForOp` is renamed to `scf::tileUsingSCF`
2) The `scf::tileUsingSCFForallOp` is removed. The same
   functionality is obtained by using `scf::tileUsingSCF` and setting
   the loop type in `scf::SCFTilingOptions` passed into this method to
   `scf::SCFTilingOptions::LoopType::ForallOp` (using the
   `setLoopType` method).
3) The `scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp` is
   renamed to `scf::tileConsumerAndFuseProducersUsingSCF`. The use of
   the `controlFn` in `scf::SCFTileAndFuseOptions` allows implementing
   any strategy, with the default callback implementing the greedy fusion.
4) The `scf::SCFTilingResult` and `scf::SCFTileAndFuseResult` now use
   `SmallVector<LoopLikeOpInterface>`.
5) To make `scf::ForallOp` implement the parts of
   `LoopLikeOpInterface` needed, the `getOutputBlockArguments()`
   method is replaced with `getRegionIterArgs()`

This change also introduces a new interface method for
`LoopLikeOpInterface`, `yieldTiledValuesAndReplace`, that allows loop
constructs to handle tiled yields.

These changes now bring the tiling and fusion capabilities using
`scf.forall` on par with what was already supported using `scf.for`.
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  21 +-
 .../SCF/Transforms/TileUsingInterface.h       |  41 +-
 .../mlir/Interfaces/LoopLikeInterface.h       |  22 +
 .../mlir/Interfaces/LoopLikeInterface.td      |  13 +
 .../TransformOps/LinalgTransformOps.cpp       |  10 +-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp |   4 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 126 ++-
 .../SCF/Transforms/TileUsingInterface.cpp     | 806 +++++++++---------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |   8 +-
 mlir/lib/Interfaces/LoopLikeInterface.cpp     |  36 +-
 .../Linalg/generalize-tensor-unpack-tile.mlir |   8 +-
 mlir/test/Dialect/Linalg/tile-conv.mlir       |   2 +-
 mlir/test/Dialect/Linalg/tile-tensors.mlir    |   2 +-
 ...-op-hoist-pad-build-packing-loop-nest.mlir |   4 +-
 .../Linalg/transform-op-hoist-pad.mlir        |   2 +
 .../transform-op-peel-and-vectorize.mlir      |   4 +-
 mlir/test/Dialect/Tensor/tiling.mlir          |  24 +-
 .../tile-and-fuse-using-scfforall.mlir        | 176 ++++
 .../tile-fuse-and-yield-using-interface.mlir  |   1 +
 .../tile-fuse-and-yield-using-scfforall.mlir  |  60 ++
 .../tile-pad-using-interface.mlir             |   5 +-
 .../TilingInterface/tile-using-interface.mlir |  18 +-
 .../TilingInterface/tile-using-scfforall.mlir | 150 +++-
 .../TestTilingInterfaceTransformOps.cpp       | 135 ++-
 .../TestTilingInterfaceTransformOps.td        |  32 +-
 25 files changed, 1176 insertions(+), 534 deletions(-)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8d65d3dd820baf..08caaa0b880b45 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -135,10 +135,10 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-       ["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
-        "getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
-        "getLoopResults", "promoteIfSingleIteration",
-        "replaceWithAdditionalYields"]>,
+       ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
+        "getSingleLowerBound", "getSingleStep", "getSingleUpperBound",
+        "getYieldedValuesMutable", "getLoopResults", "promoteIfSingleIteration",
+        "replaceWithAdditionalYields", "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -259,10 +259,6 @@ def ForOp : SCF_Op<"for",
 
     Value getInductionVar() { return getBody()->getArgument(0); }
 
-    Block::BlockArgListType getRegionIterArgs() {
-      return getBody()->getArguments().drop_front(getNumInductionVars());
-    }
-
     /// Return the `index`-th region iteration argument.
     BlockArgument getRegionIterArg(unsigned index) {
       assert(index < getNumRegionIterArgs() &&
@@ -304,8 +300,9 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["promoteIfSingleIteration", "getSingleInductionVar",
-          "getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
+          ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
+           "getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
+           "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -585,10 +582,6 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
-    ArrayRef<BlockArgument> getOutputBlockArguments() {
-      return getBody()->getArguments().drop_front(getRank());
-    }
-
     ::mlir::ValueRange getInductionVars() {
       return getBody()->getArguments().take_front(getRank());
     }
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 5d2d78e6e6165b..965ef9e203be28 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 #include <deque>
@@ -52,6 +53,14 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify which loop construct to use for tile and fuse.
+  enum class LoopType { ForOp, ForallOp };
+  LoopType loopType = LoopType::ForOp;
+  SCFTilingOptions &setLoopType(LoopType type) {
+    loopType = type;
+    return *this;
+  }
+
   /// Specify mapping of loops to devices. This is only respected when the loop
   /// constructs support such a mapping (like `scf.forall`). Will be ignored
   /// when using loop constructs that dont support such a mapping (like
@@ -71,7 +80,7 @@ struct SCFTilingResult {
   /// of the last op.
   SmallVector<Operation *> tiledOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<Operation *> loops;
+  SmallVector<LoopLikeOpInterface> loops;
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
@@ -79,15 +88,9 @@ struct SCFTilingResult {
 
 /// Method to tile an op that implements the `TilingInterface` using
 /// `scf.for` for iterating over the tiles.
-FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
-                                             TilingInterface op,
-                                             const SCFTilingOptions &options);
-
-/// Method to tile an op that implements the `TilingInterface` using
-/// `scf.forall`.
-FailureOr<SCFTilingResult>
-tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
-                     const SCFTilingOptions &options);
+FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
+                                        TilingInterface op,
+                                        const SCFTilingOptions &options);
 
 /// Options used to control tile + fuse.
 struct SCFTileAndFuseOptions {
@@ -135,7 +138,7 @@ struct SCFFuseProducerOfSliceResult {
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                            tensor::ExtractSliceOp candidateSliceOp,
-                           MutableArrayRef<scf::ForOp> loops);
+                           MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Reconstruct the fused producer from within the tiled-and-fused code. Based
 /// on the slice of the producer computed in place it is possible that within
@@ -187,10 +190,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 /// where `%0` had other uses as well. If not reconstructed from within the loop
 /// body, uses of `%0` could not be replaced, making it still live and the
 /// fusion immaterial.
-void yieldReplacementForFusedProducer(
+LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<scf::ForOp> loops);
+    MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
@@ -201,7 +204,7 @@ struct SCFTileAndFuseResult {
   /// generated operation.
   llvm::SetVector<Operation *> tiledAndFusedOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<Operation *> loops;
+  SmallVector<LoopLikeOpInterface> loops;
   /// The replacement values to use for the tiled and fused operations.
   llvm::DenseMap<Value, Value> replacements;
 };
@@ -232,9 +235,9 @@ struct SCFTileAndFuseResult {
 /// }
 /// ```
 FailureOr<SCFTileAndFuseResult>
-tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
-    RewriterBase &rewriter, TilingInterface consumer,
-    const SCFTileAndFuseOptions &options);
+tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
+                                     TilingInterface consumer,
+                                     const SCFTileAndFuseOptions &options);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
@@ -249,8 +252,8 @@ struct SCFReductionTilingResult {
   Operation *mergeOp;
   /// Initial op
   Operation *initialOp;
-  /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<scf::ForOp> loops;
+  /// The loop operations that iterate over the tiles.
+  SmallVector<LoopLikeOpInterface> loops;
 };
 
 /// Method to tile a reduction and generate a parallel op within a serial loop.
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 7c7d378d0590ab..c62476f9b62256 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -25,6 +25,28 @@ class RewriterBase;
 using NewYieldValuesFn = std::function<SmallVector<Value>(
     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
 
+/// A function that allows returning additional yielded values during
+/// `yieldTiledValuesAndReplace`.
+/// - `ivs` induction variable for the loop.
+/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
+/// - `tiledValues` the tiled values to return. Must be of same size as
+///   `newbbArgs`, each element of this array is inserted into the corresponding
+///   element in `newbbArgs`.
+/// - `resultOffsets` is of the same size as `tiledValues` and represents
+///   the offsets to use when inserting corresponding element from `tiledValues`
+///   into the element from `newBbArgs`.
+/// - `resultSizes` is of the same size as `tiledValues` and represents
+///   the size of the corresponding element from `tiledValues` inserted into
+///   the element from `newBbArgs`.
+/// - `resultStrides` is of the same size as `tiledValues` and represents
+///   the strides to use when inserting corresponding element from `tiledValues`
+///   into the element from `newBbArgs`.
+using YieldTiledValuesFn = std::function<LogicalResult(
+    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
+    SmallVector<Value> &tiledValues,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
+
 namespace detail {
 /// Verify invariants of the LoopLikeOpInterface.
 LogicalResult verifyLoopLikeOpInterface(Operation *op);
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 75d90b67bd82f3..20afc35571fbf2 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -218,6 +218,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         return ::mlir::failure();
       }]
     >,
+    InterfaceMethod<[{
+        TODO
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
+      /*methodName=*/"yieldTiledValuesAndReplace",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
+                    "::mlir::ValueRange":$newInitOperands,
+                    "const ::mlir::YieldTiledValuesFn &":$yieldTiledValuesFn),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 97d2b4a3be5c56..5efa2b3eb7476f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -485,8 +485,8 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
       [&](TilingInterface tilingInterfaceOp)
           -> FailureOr<scf::SCFTileAndFuseResult> {
-        return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
-            rewriter, tilingInterfaceOp, tileAndFuseOptions);
+        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
+                                                    tileAndFuseOptions);
       });
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
@@ -584,7 +584,7 @@ static Operation *replaceForAllWithNewSignature(
   Operation *firstYieldOp = yieldingOps.front();
   rewriter.setInsertionPoint(firstYieldOp);
   Value src = tileAndFuseResult.tiledValues[0];
-  Value dst = newforallOp.getOutputBlockArguments().back();
+  Value dst = newforallOp.getRegionIterArgs().back();
   SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
   rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
                                                  dst, offsets, sizes, strides);
@@ -2063,7 +2063,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
   });
   SmallVector<int64_t> emptyTileSizes;
   rewriter.setInsertionPoint(target);
-  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
+  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
       rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
   if (failed(maybeTilingResult))
     return emitDefaultDefiniteFailure(target);
@@ -2647,7 +2647,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
 
     tilingOptions.setInterchange(getInterchange());
     FailureOr<scf::SCFTilingResult> maybeTilingResult =
-        tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
+        tileUsingSCF(rewriter, tilingInterface, tilingOptions);
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7f3ab1f1a24b2f..339d06cdeaf60a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -358,7 +358,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
 
   // 3. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
   Operation *tiledOp = nullptr;
   SmallVector<Value> tiledValues;
   {
@@ -695,7 +695,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   // 4. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   SmallVector<Value> tilingResults;
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
   {
     // 4.a. RAII guard, inserting within forallOp, before terminator.
     OpBuilder::InsertionGuard g(b);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5570c2ec688c8a..31101861ad6f45 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -527,6 +527,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
 
 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
 
+Block::BlockArgListType ForOp::getRegionIterArgs() {
+  return getBody()->getArguments().drop_front(getNumInductionVars());
+}
+
 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
   return getInitArgsMutable();
 }
@@ -584,6 +588,63 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
   return cast<LoopLikeOpInterface>(newLoop.getOperation());
 }
 
+FailureOr<LoopLikeOpInterface> ForOp::yieldTiledValuesAndReplace(
+    RewriterBase &rewriter, ValueRange newInitOperands,
+    const YieldTiledValuesFn &yieldTiledValuesFn) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+
+  auto inits = llvm::to_vector(getInitArgs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
+      [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           getBody()->getNumArguments()));
+
+  auto yieldOp = cast<scf::YieldOp>(newLoop.getBody()->getTerminator());
+  rewriter.setInsertionPoint(yieldOp);
+
+  SmallVector<Value> tiledValues;
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  ValueRange newRegionIterArgs =
+      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
+  if (failed(yieldTiledValuesFn(rewriter, getLoc(), newLoop.getInductionVar(),
+                                newRegionIterArgs, tiledValues, resultOffsets,
+                                resultSizes))) {
+    return rewriter.notifyMatchFailure(getOperation(),
+                                       "failed to get tiled values");
+  }
+
+  if (tiledValues.size() != resultOffsets.size() ||
+      tiledValues.size() != resultSizes.size()) {
+    return rewriter.notifyMatchFailure(
+        getOperation(),
+        "expected number of tiled values returned, the number of offset "
+        "vectors and number of size vectors to be the same");
+  }
+
+  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
+  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
+       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
+                       resultSizes)) {
+    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
+                                           rewriter.getIndexAttr(1));
+    Value insert = rewriter.create<tensor::InsertSliceOp>(
+        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
+        resultStride);
+    newYieldValues.push_back(insert);
+  }
+
+  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
+  rewriter.replaceOp(getOperation(),
+                     newLoop->getResults().take_front(getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 ForOp mlir::scf::getForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
   if (!ivArg)
@@ -622,6 +683,61 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
   return success();
 }
 
+Block::BlockArgListType ForallOp::getRegionIterArgs() {
+  return getBody()->getArguments().drop_front(getRank());
+}
+
+MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
+  return getOutputsMutable();
+}
+
+FailureOr<LoopLikeOpInterface> ForallOp::yieldTiledValuesAndReplace(
+    RewriterBase &rewriter, ValueRange newInitOperands,
+    const YieldTiledValuesFn &yieldTiledValuesFn) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+  auto inits = llvm::to_vector(getOutputs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  auto newLoop = rewriter.create<scf::ForallOp>(
+      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
+      inits, getMapping(), [](OpBuilder &, Location, ValueRange) {});
+
+  // Move the region of the current block to the newly created op.
+  Block *newLoopBody = newLoop.getBody();
+  rewriter.mergeBlocks(
+      getBody(), newLoopBody,
+      newLoopBody->getArguments().take_front(getBody()->getNumArguments()));
+
+  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
+  rewriter.setInsertionPoint(terminator);
+  SmallVector<Value> tiledValues;
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  ValueRange regionIterArgs =
+      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
+  if (failed(yieldTiledValuesFn(rewriter, getLoc(), newLoop.getInductionVars(),
+                                regionIterArgs, tiledValues, resultOffsets,
+                                resultSizes))) {
+    return rewriter.notifyMatchFailure(getOperation(),
+                                       "failed to get yielded tiled values");
+  }
+
+  // Update the terminator.
+  rewriter.setInsertionPointToEnd(terminator.getBody());
+
+  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
+           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
+                                           rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
+        resultStride);
+  }
+
+  rewriter.replaceOp(getOperation(),
+                     newLoop->getResults().take_front(getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 /// Promotes the loop body of a scf::ForallOp to its containing block.
 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   OpBuilder::InsertionGuard g(rewriter);
@@ -1355,11 +1471,6 @@ void ForallOp::build(
     return;
   }
   bodyBuilderFn(b, result.location, bodyBlock.getArguments());
-#ifndef NDEBUG
-  auto terminator = llvm::dyn_cast<InParallelOp>(bodyBlock.getTerminator());
-  assert(terminator &&
-         "expected bodyBuilderFn to create InParallelOp terminator");
-#endif // NDEBUG
 }
 
 // Builder that takes loop bounds.
@@ -1630,9 +1741,8 @@ struct FoldTensorCastOfOutputIntoForallOp
     // mapped to the tensor.cast old-typed results of the output bbArgs. The
     // destination have to be updated to point to the output bbArgs directly.
     auto terminator = newForallOp.getTerminator();
-    for (auto [yieldingOp, outputBlockArg] :
-         llvm::zip(terminator.getYieldingOps(),
-                   newForallOp.getOutputBlockArguments())) {
+    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
+             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
       auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
       insertSliceOp.getDestMutable().assign(outputBlockArg);
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 38e0625d7ce093..50a85e6e34e240 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -55,32 +55,8 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
   return filledVector;
 }
 
-/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
-template <typename SrcOpTy>
-static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
-  return llvm::to_vector(
-      llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
-}
-template <typename SrcOpTy>
-static SmallVector<Operation *>
-getAsOperations(const SmallVector<SrcOpTy> &ops) {
-  return getAsOperations(ArrayRef<SrcOpTy>(ops));
-}
-
-/// Convert a list of `Operation *` to a list of `DstOpTy.
-template <typename DstOpTy>
-static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
-  return llvm::to_vector(
-      llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
-}
-template <typename DstOpTy>
-static SmallVector<DstOpTy>
-castToTypedOperations(const SmallVector<Operation *> &ops) {
-  return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
-}
-
 //===----------------------------------------------------------------------===//
-// tileUsingSCFForOp implementation.
+// tileUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
 // Check if `stride` evenly divides the trip count `size - offset`.
@@ -135,66 +111,180 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
   return clonedOp;
 }
 
-/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// Generate the tile-loop nest using `scf.for` operation.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
-/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
-///   the tile processed within the inner most loop.
-/// Note that this methods adds `scf.yield` operation for all but the innermost
-/// loop. These yield the value returned by the immediately inner loop. The
-/// caller is expected to add the scf.yield operation for the innermost loop.
-static SmallVector<scf::ForOp> generateTileLoopNest(
-    OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
-    SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
-  if (loopRanges.empty())
-    return {};
+/// - `destinationTensors` are the init values to use for the outer most loop.
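+/// - `yieldTiledValuesFn` is called to generate the body of the innermost
+///    loop. It populates the tiled values and the offsets/sizes at which
+///    they are inserted back into the destinations.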
+/// - `loops` is an in-out parameter into which the generated loops are
+///    populated.
+static LogicalResult generateLoopNestUsingForOp(
+    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
+    YieldTiledValuesFn yieldTiledValuesFn,
+    SmallVector<LoopLikeOpInterface> &loops) {
+  assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
-  OpBuilder::InsertionGuard guard(builder);
-  SmallVector<scf::ForOp> loops;
-  offsets.resize(loopRanges.size());
-  sizes.resize(loopRanges.size());
-
-  for (auto loopRange : llvm::enumerate(loopRanges)) {
-    Value offset =
-        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
-    Value size =
-        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
-    Value tileSize = getValueOrCreateConstantIndexOp(
-        builder, loc, tileSizes[loopRange.index()]);
+  OpBuilder::InsertionGuard guard(rewriter);
+  SmallVector<Value> ivs;
+
+  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
     // No loops if tile size is zero. Set offset and size to the loop
     // offset and size.
-    if (matchPattern(tileSize, m_Zero())) {
-      offsets[loopRange.index()] = offset;
-      sizes[loopRange.index()] = size;
+    if (isConstantIntValue(tileSize, 0))
       continue;
-    }
 
-    auto loop = builder.create<scf::ForOp>(
-        loc, offset, size, tileSize, destinationTensors,
-        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
-            ValueRange /*iterArgs*/) {
-          sizes[loopRange.index()] =
-              getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv,
-                                 getAsOpFoldResult(tileSize));
-        });
-    offsets[loopRange.index()] = loop.getInductionVar();
+    Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+    Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+    auto loop =
+        rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
+                                    [](OpBuilder &bodyBuilder, Location bodyLoc,
+                                       Value iv, ValueRange /*iterArgs*/) {});
     loops.push_back(loop);
-    builder.setInsertionPointToEnd(loop.getBody());
+    ivs.push_back(loop.getInductionVar());
+    rewriter.setInsertionPointToEnd(loop.getBody());
     destinationTensors = loop.getRegionIterArgs();
   }
 
+  SmallVector<Value> tiledResults;
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
+                                tiledResults, resultOffsets, resultSizes))) {
+    return rewriter.notifyMatchFailure(
+        loc, "failed to generate inner tile loop body");
+  }
+  if (loops.empty())
+    return success();
+
+  // Yield all the results of the tiled operation.
+  SmallVector<Value> yieldedValues;
+  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
+       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
+                       resultSizes)) {
+    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
+                                           rewriter.getIndexAttr(1));
+    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledValue, destinationTensor, resultOffset, resultSize,
+        resultStride);
+    yieldedValues.push_back(insertSlice);
+  }
+  rewriter.create<scf::YieldOp>(loc, yieldedValues);
+
   // Add the scf.yield operations for all the outer loops.
-  if (!loops.empty()) {
-    for (auto [outerLoop, innerLoop] :
-         llvm::zip_equal(MutableArrayRef(loops).drop_back(),
-                         MutableArrayRef(loops).drop_front())) {
-      builder.setInsertionPointToEnd(outerLoop.getBody());
-      builder.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop.getResults());
-    }
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip_equal(MutableArrayRef(loops).drop_back(),
+                       MutableArrayRef(loops).drop_front())) {
+    rewriter.setInsertionPointToEnd(
+        cast<scf::ForOp>(outerLoop.getOperation()).getBody());
+    rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
   }
-  return loops;
+  return success();
+}
+
+/// Generate the tile-loop nest using `scf.forall` operation.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
+/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
+/// - `destinationTensors` are the init values to use for the outer most loop.
+/// - `mappingVector` is the mapping attributes to use for loop construction.
+///   Can be empty.
+/// - `tiledBodyFn` is called to generate the body of the `scf.forall`
+///    loop. It populates the tiled values and the offsets/sizes at which
+///    they are inserted back into the destinations.
+/// - `loops` is an in-out parameter into which the generated loops are
+///    populated.
+static LogicalResult generateLoopNestUsingForallOp(
+    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
+    ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
+    SmallVector<LoopLikeOpInterface> &loops) {
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  assert(!loopRanges.empty() && "unexpected empty loop ranges");
+  assert(loopRanges.size() == tileSizes.size() &&
+         "expected as many tile sizes as loop ranges");
+  OpBuilder::InsertionGuard guard(rewriter);
+  SmallVector<OpFoldResult> offsets(loopRanges.size()),
+      sizes(loopRanges.size());
+
+  for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+  assert(!lbs.empty() && "Expected at least one loop range");
+
+  std::optional<ArrayAttr> mappingAttr;
+  if (!mappingVector.empty())
+    mappingAttr = rewriter.getArrayAttr(mappingVector);
+
+  auto forallOp = rewriter.create<scf::ForallOp>(
+      loc, lbs, ubs, steps, destinationTensors, mappingAttr);
+  loops.push_back(forallOp);
+
+  rewriter.setInsertionPoint(forallOp.getTerminator());
+  destinationTensors = forallOp.getRegionOutArgs();
+
+  SmallVector<Value> tiledResults;
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
+                         destinationTensors, tiledResults, resultOffsets,
+                         resultSizes)))
+    return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
+
+  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
+  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
+       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
+                       resultSizes)) {
+    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
+                                           rewriter.getIndexAttr(1));
+
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        loc, tiledValue, destinationTensor, resultOffset, resultSize,
+        resultStride);
+  }
+  return success();
+}
+
+/// Generate the tile-loop nest using the loop construct specified in `options`.
+/// - `options`: Tiling options specified.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
+/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
+/// - `destinationTensors` are the init values to use for the outer most loop.
+/// - `tiledBodyFn` is called to generate the body of the innermost loop
+///    (or directly, with no induction variables, when all tile sizes are
+///    zero).
+/// - `loops` is an in-out parameter into which the generated loops are
+///    populated.
+static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
+                                      const scf::SCFTilingOptions &options,
+                                      ArrayRef<Range> loopRanges,
+                                      ArrayRef<OpFoldResult> tileSizes,
+                                      ValueRange destinationTensors,
+                                      YieldTiledValuesFn tiledBodyFn,
+                                      SmallVector<LoopLikeOpInterface> &loops) {
+  // If the tile sizes are all zero, no loops are generated. Just call the
+  // callback function to handle untiled case.
+  if (llvm::all_of(tileSizes, isZeroIndex)) {
+    SmallVector<Value> tiledResults;
+    SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+    return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
+                       tiledResults, resultOffsets, resultSizes);
+  }
+  if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
+    return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
+                                      destinationTensors, tiledBodyFn, loops);
+  }
+  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
+    return generateLoopNestUsingForallOp(
+        rewriter, loc, loopRanges, tileSizes, options.mappingVector,
+        destinationTensors, tiledBodyFn, loops);
+  }
+  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
 /// Method to add new init values to a loop nest. Updates `loops` in-place with
@@ -202,26 +292,28 @@ static SmallVector<scf::ForOp> generateTileLoopNest(
 /// The outer-loops are updated to yield the new result values of the inner
 /// loop. For the innermost loop, the call back `getNewYields` is invoked to get
 /// the additional values to yield form the innermost loop.
-static void addInitOperandsToLoopNest(
-    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loops,
-    ValueRange newInitValues,
-    llvm::function_ref<SmallVector<Value>(RewriterBase &rewriter, Value iv,
-                                          ValueRange newRegionIterArgs)>
-        getNewYieldValsFn) {
+static LogicalResult addInitOperandsToLoopNest(
+    RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
+    ValueRange newInitValues, const YieldTiledValuesFn &getNewTiledYieldsFn) {
   SmallVector<scf::ForOp> newLoops;
   if (loops.empty())
-    return;
+    return success();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(loops.front());
-  for (auto &loop : loops) {
+
+  SmallVector<Value> ivs;
+  for (auto &loop : loops.drop_back()) {
     rewriter.setInsertionPoint(loop);
 
+    // If loops.size() > 1, we assume that scf.for is used for the outer loops.
+    auto forLoop = cast<scf::ForOp>(loop.getOperation());
+
     // Create a new loop with the new init values for this loop.
-    SmallVector<Value> newInits = llvm::to_vector(loop.getInitArgs());
+    SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
     newInits.append(newInitValues.begin(), newInitValues.end());
     auto newLoop = rewriter.create<scf::ForOp>(
-        loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
-        loop.getStep(), newInits,
+        forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
+        forLoop.getStep(), newInits,
         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
 
     // Merge the body of the new loop with the body of the old loops.
@@ -230,48 +322,49 @@ static void addInitOperandsToLoopNest(
     auto newRegionIterArgs = newLoop.getRegionIterArgs();
     sourceBlockArgs.append(
         newRegionIterArgs.begin(),
-        std::next(newRegionIterArgs.begin(), loop.getNumResults()));
-    rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(), sourceBlockArgs);
-    rewriter.replaceOp(loop,
-                       newLoop.getResults().take_front(loop.getNumResults()));
+        std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
+    rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
+    rewriter.replaceOp(
+        forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
     loop = newLoop;
+    ivs.push_back(newLoop.getInductionVar());
     newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
   }
 
   // Update the loop body of the innermost loop to get new yield values.
-  scf::ForOp innerMostLoop = loops.back();
-  auto innerMostYieldOp =
-      cast<scf::YieldOp>(innerMostLoop.getBody()->getTerminator());
-  rewriter.setInsertionPoint(innerMostYieldOp);
-  SmallVector<Value> newYieldVals =
-      getNewYieldValsFn(rewriter, innerMostLoop.getInductionVar(),
-                        innerMostLoop.getRegionIterArgs());
-  SmallVector<Value> newYieldOperands =
-      llvm::to_vector(innerMostYieldOp->getOperands());
-  newYieldOperands.append(newYieldVals);
-  rewriter.replaceOpWithNewOp<scf::YieldOp>(innerMostYieldOp, newYieldOperands);
+  LoopLikeOpInterface innerMostLoop = loops.back();
+  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
+      innerMostLoop.yieldTiledValuesAndReplace(rewriter, newInitValues,
+                                               getNewTiledYieldsFn);
+
+  if (failed(newInnerMostLoop))
+    return innerMostLoop.emitOpError("failed to return additional yields");
+  loops.back() = newInnerMostLoop.value();
 
   // Make all other loops except the innermost loops yield the values returned
   // by the inner loop.
   for (auto [outerLoop, innerLoop] :
        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+    // Again assume that all the outer loops are scf.for operations.
+    auto outerForLoop = cast<scf::ForOp>(outerLoop);
     auto outerLoopYield =
-        cast<scf::YieldOp>(outerLoop.getBody()->getTerminator());
+        cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
     SmallVector<Value> newYields =
         llvm::to_vector(outerLoopYield.getOperands());
     ValueRange additionalYields =
-        innerLoop.getResults().take_back(newInitValues.size());
+        innerLoop->getResults().take_back(newInitValues.size());
     newYields.append(additionalYields.begin(), additionalYields.end());
     rewriter.setInsertionPoint(outerLoopYield);
     rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
   }
+  return success();
 }
 
 /// Implementation of tiling transformation of `op` that implements the
 /// `TilingInterface` using `scf.for` to iterate over the tiles.
 FailureOr<scf::SCFTilingResult>
-mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
-                             const scf::SCFTilingOptions &options) {
+mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
+                        const scf::SCFTilingOptions &options) {
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfter(op);
 
@@ -288,145 +381,135 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  SmallVector<OpFoldResult> tileSizeVector =
+  SmallVector<OpFoldResult> tileSizes =
       options.tileSizeComputationFunction(rewriter, op);
-  if (tileSizeVector.size() < iterationDomain.size()) {
+  if (tileSizes.size() < iterationDomain.size()) {
     auto zero = rewriter.getIndexAttr(0);
-    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+    tileSizes.append(numLoops - tileSizes.size(), zero);
   }
 
-  // 3. Find the destination tensors to use for the operation.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "unable to create destination tensors");
+  // 3. If there is an interchange specified, permute the iteration domain and
+  // the tile sizes.
+  SmallVector<int64_t> interchangeVector;
+  if (!options.interchangeVector.empty()) {
+    interchangeVector = fillInterchangeVector(options.interchangeVector,
+                                              iterationDomain.size());
   }
-
-  SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> forLoops;
-  {
-    // If there is an interchange specified, permute the iteration domain and
-    // the tile sizes.
-    SmallVector<int64_t> interchangeVector;
-    if (!options.interchangeVector.empty()) {
-      interchangeVector = fillInterchangeVector(options.interchangeVector,
-                                                iterationDomain.size());
+  if (!interchangeVector.empty()) {
+    if (!isPermutationVector(interchangeVector)) {
+      return rewriter.notifyMatchFailure(
+          op, "invalid intechange vector, not a permutation of the entire "
+              "iteration space");
     }
-    if (!interchangeVector.empty()) {
-      if (!isPermutationVector(interchangeVector)) {
-        return rewriter.notifyMatchFailure(
-            op, "invalid intechange vector, not a permutation of the entire "
-                "iteration space");
-      }
 
-      applyPermutationToVector(iterationDomain, interchangeVector);
-      applyPermutationToVector(tileSizeVector, interchangeVector);
+    applyPermutationToVector(iterationDomain, interchangeVector);
+    applyPermutationToVector(tileSizes, interchangeVector);
+  }
+
+  FailureOr<TilingResult> tilingResult;
+  // 4. Define the lambda function used later to generate the body of the
+  // innermost tiled loop.
+  YieldTiledValuesFn innerYieldTiledValuesFn =
+      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+          ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
+          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+      -> LogicalResult {
+    // 4a. Compute the `offsets` and `sizes` to use for tiling.
+    SmallVector<OpFoldResult> offsets, sizes;
+    {
+      int materializedLoopNum = 0;
+      for (auto [tileSize, loopRange] :
+           llvm::zip_equal(tileSizes, iterationDomain)) {
+        if (isConstantIntValue(tileSize, 0)) {
+          offsets.push_back(loopRange.offset);
+          sizes.push_back(loopRange.size);
+          continue;
+        }
+        Value iv = ivs[materializedLoopNum++];
+        offsets.push_back(iv);
+        sizes.push_back(
+            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+      }
     }
 
-    // 4. Materialize an empty loop nest that iterates over the tiles. These
-    // loops for now do not return any values even if the original operation has
-    // results.
-    forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
-                                    tileSizeVector, offsets, sizes,
-                                    destinationTensors);
-
+    // 4b. If interchange was provided, apply inverse of the interchange
+    //     to get back the offsets/sizes in the order to be specified.
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
       applyPermutationToVector(offsets, inversePermutation);
       applyPermutationToVector(sizes, inversePermutation);
     }
-  }
 
-  LLVM_DEBUG({
-    if (!forLoops.empty()) {
-      llvm::dbgs() << "LoopNest shell :\n";
-      forLoops.front().dump();
-      llvm::dbgs() << "\n";
-    }
-  });
+    // 5. Generate the tiled implementation within the inner most loop.
 
-  // 5. Generate the tiled implementation within the inner most loop.
-  SmallVector<Value> clonedOpDestination = destinationTensors;
-  if (!forLoops.empty()) {
-    rewriter.setInsertionPointToEnd(forLoops.back().getBody());
-    clonedOpDestination =
-        llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
-                            [](BlockArgument b) -> Value { return b; });
-  }
+    // 5a. Clone the operation within the loop body.
+    auto clonedOp = cast<TilingInterface>(
+        cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
 
-  // 5a. Clone the operation within the loop body.
-  auto clonedOp = cast<TilingInterface>(
-      cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
+    // 5b. Early return cloned op if tiling is not happening. We cannot return
+    // the original op because it could lead to
+    // `rewriter.replaceOp(op, op->getResults())` and users would get a crash.
+    if (llvm::all_of(tileSizes, isZeroIndex)) {
+      tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
+      tilingResult =
+          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+      return success();
+    }
 
-  // 5b. Early return cloned op if tiling is not happening. We can not return
-  // the original op because it could lead to
-  // `rewriter.replaceOp(op, op->getResults())` and user would get crash.
-  if (llvm::all_of(tileSizeVector, isZeroIndex)) {
-    return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{},
-                                clonedOp->getResults()};
-  }
+    // 5c. Tile the cloned operation.
+    tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+    if (failed(tilingResult))
+      return op.emitOpError("failed to tile operation");
+
+    // 5d. Delete the cloned operation.
+    rewriter.eraseOp(clonedOp);
+
+    // 5e. Compute the offsets and sizes at which the result values are to
+    //     be inserted back into their destinations.
+    for (auto [index, tiledValue] :
+         llvm::enumerate(tilingResult->tiledValues)) {
+      tiledResults.push_back(tiledValue);
+      SmallVector<OpFoldResult> resultOffset, resultSize;
+      if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                          resultOffset, resultSize))) {
+        return rewriter.notifyMatchFailure(
+            op, "failed to get slice of result produced");
+      }
+      resultOffsets.emplace_back(std::move(resultOffset));
+      resultSizes.emplace_back(std::move(resultSize));
+    }
+
+    return success();
+  };
 
-  // 5c. Tile the cloned operation.
-  FailureOr<TilingResult> tiledImplementation =
-      clonedOp.getTiledImplementation(rewriter, offsets, sizes);
-  if (failed(tiledImplementation)) {
-    return rewriter.notifyMatchFailure(op, "failed to tile operation");
+  // 6. Find the destination tensors to use for the operation.
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+                                             destinationTensors))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unable to create destination tensors");
   }
 
-  // 5d. Delete the cloned operation.
-  rewriter.eraseOp(clonedOp);
+  // 7. Generate the tiled loops nest using the callback defined above.
+  SmallVector<LoopLikeOpInterface> loops;
+  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
+                              tileSizes, destinationTensors,
+                              innerYieldTiledValuesFn, loops)))
+    return op.emitOpError("failed to generate tiling loops");
+  assert(succeeded(tilingResult) &&
+         "expected tiling result to be computed after loop generation");
 
   // If loops are empty, the tiled op is used as the replacement for the untiled
   // op.
-  if (forLoops.empty()) {
-    return scf::SCFTilingResult{tiledImplementation->tiledOps,
-                                getAsOperations(forLoops),
-                                tiledImplementation->tiledValues};
-  }
-
-  if (op->getNumResults() == 0) {
-    // The innermost loop does not have a `scf.yield` yet. There is nothing to
-    // return, so generate an empty `scf.yield` operation.
-    rewriter.setInsertionPointToEnd(forLoops.back().getBody());
-    rewriter.create<scf::YieldOp>(op->getLoc());
-    return scf::SCFTilingResult{
-        tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
-  }
-
-  // 6. Yield all the results of the tiled operation.
-  int64_t numResults = op->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
-      resultSizesList(numResults);
-  SmallVector<Value> yieldedValues;
-  for (auto [index, tiledValue] :
-       llvm::enumerate(tiledImplementation->tiledValues)) {
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
-                                        resultOffsets, resultSizes))) {
-      return rewriter.notifyMatchFailure(
-          op, "failed to get slice of result produced");
-    }
-    SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
-                                            rewriter.getIndexAttr(1));
-    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
-        op->getLoc(), tiledValue, clonedOpDestination[index], resultOffsets,
-        resultSizes, resultStrides);
-    yieldedValues.push_back(insertSlice);
+  if (loops.empty()) {
+    return scf::SCFTilingResult{tilingResult->tiledOps, loops,
+                                tilingResult->tiledValues};
   }
-  rewriter.create<scf::YieldOp>(op->getLoc(), yieldedValues);
 
   SmallVector<Value> replacements = llvm::map_to_vector(
-      forLoops.front().getResults(), [](OpResult r) -> Value { return r; });
-  LLVM_DEBUG({
-    if (!forLoops.empty()) {
-      llvm::dbgs() << "After tiled implementation :\n";
-      forLoops.front().dump();
-      llvm::dbgs() << "\n";
-    }
-  });
-  return scf::SCFTilingResult{tiledImplementation->tiledOps,
-                              getAsOperations(forLoops), replacements};
+      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
+  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -464,50 +547,72 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
                                 "cannot create a tensor of identity value.");
-  // 3. Create the nested loops.
-  SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> loops =
-      generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
-                           sizes, identityTensor.value()->getResults());
-
-  // 4. Generate the tiled implementation within the inner most loop.
-  // 4a. Clone the operation within the loop body.
-  SmallVector<Value> clonedOpDestination =
+
+  // 3. Define the callback to use for generating the inner most tile loop body.
+  Operation *parallelOp = nullptr;
+  auto innerYieldTiledValuesFn =
+      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+          ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+      -> LogicalResult {
+    SmallVector<OpFoldResult> offsets, sizes;
+    {
+      int materializedLoopNum = 0;
+      for (auto [tileSize, loopRange] :
+           llvm::zip_equal(tileSizesVector, iterationDomain)) {
+        if (isConstantIntValue(tileSize, 0)) {
+          offsets.push_back(loopRange.offset);
+          sizes.push_back(loopRange.size);
+          continue;
+        }
+        Value iv = ivs[materializedLoopNum++];
+        offsets.push_back(iv);
+        sizes.push_back(
+            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+      }
+    }
+
+    // 4a. Clone the operation.
+    auto clonedOp = cast<PartialReductionOpInterface>(
+        cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+
+    // 4b. Tile the cloned operation.
+    parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
+                                                 offsets, sizes, reductionDims);
+    // 4c. Delete the cloned operation.
+    b.eraseOp(clonedOp);
+
+    tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
+    // 4d. Compute the offsets and sizes needed to insert the result of the
+    // tiled value back into destination before yielding the destination.
+    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+    resultOffsets.emplace_back(std::move(outOffsets));
+
+    SmallVector<OpFoldResult> outSizes;
+    for (size_t i = 0; i < offsets.size(); i++) {
+      outSizes.push_back(
+          tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+    }
+    resultSizes.emplace_back(std::move(outSizes));
+    return success();
+  };
+
+  // 5. Generate the tiled implementation using the destination tensors.
+  SmallVector<Value> destinationTensors =
       llvm::map_to_vector(identityTensor.value()->getResults(),
                           [](OpResult res) -> Value { return res; });
-  if (!loops.empty()) {
-    b.setInsertionPointToEnd(loops.back().getBody());
-    clonedOpDestination =
-        llvm::map_to_vector(loops.back().getRegionIterArgs(),
-                            [](BlockArgument b) -> Value { return b; });
-  }
-  auto clonedOp = cast<PartialReductionOpInterface>(
-      cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
-
-  // 4b. Tile the cloned operation.
-  Operation *parallelOp = clonedOp.tileToPartialReduction(
-      b, loc, clonedOpDestination, offsets, sizes, reductionDims);
-  // 4c. Delete the cloned operation.
-  b.eraseOp(clonedOp);
-
-  SmallVector<OpFoldResult> outSizes;
-  for (size_t i = 0; i < offsets.size(); i++) {
-    outSizes.push_back(
-        tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
-  }
-  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-  SmallVector<OpFoldResult> outStrides(outOffsets.size(), b.getIndexAttr(1));
-  SmallVector<Value> yieldedVals;
-  auto bbArgs = loops.back().getRegionIterArgs();
-  for (auto [result, bbArg] : llvm::zip(parallelOp->getResults(), bbArgs)) {
-    Value insert = b.create<tensor::InsertSliceOp>(
-        loc, result, bbArg, outOffsets, outSizes, outStrides);
-    yieldedVals.push_back(insert);
-  }
-  b.create<scf::YieldOp>(loc, yieldedVals);
+
+  SmallVector<LoopLikeOpInterface> loops;
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+                              destinationTensors, innerYieldTiledValuesFn,
+                              loops)))
+    return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
 
   SmallVector<Value> replacements = llvm::map_to_vector(
-      loops.front().getResults(), [](OpResult r) -> Value { return r; });
+      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
 
   // 5. Apply the merge reduction to combine all the partial values.
   b.setInsertionPointAfter(*loops.begin());
@@ -516,14 +621,14 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 
   SCFReductionTilingResult results;
   results.initialOp = *identityTensor;
-  results.loops = std::move(loops);
+  results.loops = loops;
   results.parallelTiledOp = parallelOp;
   results.mergeOp = mergeOp;
   return results;
 }
 
 //===----------------------------------------------------------------------===//
-// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
+// tileConsumerAndFuseProducersUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
 /// Return the untiled producer whose slice is used in a tiled consumer. The
@@ -533,11 +638,11 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 /// no loop traversal needed, the second value of the returned tuple is empty.
 static std::tuple<OpResult, std::optional<OpOperand *>>
 getUntiledProducerFromSliceSource(OpOperand *source,
-                                  ArrayRef<scf::ForOp> loops) {
+                                  ArrayRef<LoopLikeOpInterface> loops) {
   std::optional<OpOperand *> destinationIterArg;
   auto loopIt = loops.rbegin();
   while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
-    scf::ForOp loop = *loopIt;
+    auto loop = *loopIt;
     if (iterArg.getOwner()->getParentOp() != loop)
       break;
     source = loop.getTiedLoopInit(iterArg);
@@ -551,9 +656,9 @@ getUntiledProducerFromSliceSource(OpOperand *source,
 /// Implementation of fusing producer of a single slice by computing the
 /// slice of the producer in-place.
 std::optional<scf::SCFFuseProducerOfSliceResult>
-mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
-                                      tensor::ExtractSliceOp candidateSliceOp,
-                                      MutableArrayRef<scf::ForOp> loops) {
+mlir::scf::tileAndFuseProducerOfSlice(
+    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
   auto [fusableProducer, destinationInitArg] =
@@ -666,12 +771,12 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
-void mlir::scf::yieldReplacementForFusedProducer(
+LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<scf::ForOp> loops) {
+    MutableArrayRef<LoopLikeOpInterface> loops) {
   if (loops.empty())
-    return;
+    return success();
 
   OpResult fusableProducer = fusedProducerInfo.origProducer;
   Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
@@ -679,15 +784,18 @@ void mlir::scf::yieldReplacementForFusedProducer(
       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
   if (succeeded(initValue)) {
 
-    auto newYieldValuesFn =
-        [&](RewriterBase &innerRewriter, Value iv,
-            ValueRange newRegionIterArgs) -> SmallVector<Value> {
+    YieldTiledValuesFn newYieldValuesFn =
+        [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+            ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+            SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+            SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
+        -> LogicalResult {
       OpBuilder::InsertionGuard g(innerRewriter);
       if (auto tiledDestStyleOp =
               tiledAndFusedProducer
                   .getDefiningOp<DestinationStyleOpInterface>()) {
         rewriter.setInsertionPoint(tiledDestStyleOp);
-        BlockArgument newRegionArg = loops.back().getRegionIterArgs().back();
+        Value newRegionArg = newRegionIterArgs.back();
         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
             sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
             sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
@@ -698,23 +806,22 @@ void mlir::scf::yieldReplacementForFusedProducer(
       }
       Block *block = rewriter.getInsertionPoint()->getBlock();
       rewriter.setInsertionPoint(block->getTerminator());
-      Value replacement = rewriter.create<tensor::InsertSliceOp>(
-          fusedProducerInfo.origProducer.getLoc(),
-          fusedProducerInfo.tiledAndFusedProducer,
-          loops.back().getRegionIterArgs().back(), sliceOp.getMixedOffsets(),
-          sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
-      return {replacement};
+      tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
+      tiledOffset.emplace_back(sliceOp.getMixedOffsets());
+      tiledSizes.emplace_back(sliceOp.getMixedSizes());
+      return success();
     };
 
-    addInitOperandsToLoopNest(rewriter, loops,
-                              SmallVector<Value>{initValue.value()},
-                              newYieldValuesFn);
+    return addInitOperandsToLoopNest(rewriter, loops,
+                                     SmallVector<Value>{initValue.value()},
+                                     newYieldValuesFn);
   }
+  return success();
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
 FailureOr<scf::SCFTileAndFuseResult>
-mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+mlir::scf::tileConsumerAndFuseProducersUsingSCF(
     RewriterBase &rewriter, TilingInterface consumer,
     const scf::SCFTileAndFuseOptions &options) {
   // This transformation is only valid for ops that return values (i.e. not
@@ -727,24 +834,25 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   // 1. First tile the consumer.
   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
   llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
+
   FailureOr<scf::SCFTilingResult> tilingResult =
-      tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
+      tileUsingSCF(rewriter, consumer, options.tilingOptions);
+
   if (failed(tilingResult))
     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
   for (auto *tiledOp : tilingResult->tiledOps)
     tiledAndFusedOps.insert(tiledOp);
-  SmallVector<scf::ForOp> forLoops =
-      castToTypedOperations<scf::ForOp>(tilingResult->loops);
 
   // If there are no loops generated, fusion is immaterial.
-  if (forLoops.empty()) {
+  auto &loops = tilingResult->loops;
+  if (loops.empty()) {
     DenseMap<Value, Value> replacements;
     for (auto [origVal, replacement] :
          llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
       replacements[origVal] = replacement;
     }
-    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
-                                     getAsOperations(forLoops), replacements};
+    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
+                                     replacements};
   }
 
   // To keep track of replacements for now just record the map from the original
@@ -780,7 +888,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     // Find the original producer of the slice.
     auto [fusableProducer, destinationInitArg] =
         getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
-                                          forLoops);
+                                          loops);
     if (!fusableProducer)
       continue;
 
@@ -793,15 +901,19 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
-        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
+        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
     if (!fusedResult)
       continue;
 
     if (yieldReplacement) {
-      yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
-                                       fusedResult.value(), forLoops);
+      if (failed(yieldReplacementForFusedProducer(
+              rewriter, candidateSliceOp, fusedResult.value(), loops))) {
+        return rewriter.notifyMatchFailure(
+            fusableProducer.getOwner(), "failed to replacement value for this "
+                                        "operation from within the tiled loop");
+      }
       origValToResultNumber[fusableProducer] =
-          forLoops.front().getNumResults() - 1;
+          loops.front()->getNumResults() - 1;
     }
 
     if (Operation *tiledAndFusedOp =
@@ -814,123 +926,11 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
 
   DenseMap<Value, Value> replacements;
   for (auto [origVal, resultNumber] : origValToResultNumber) {
-    replacements[origVal] = forLoops.front()->getResult(resultNumber);
-  }
-
-  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
-                                   getAsOperations(forLoops), replacements};
-}
-
-//===----------------------------------------------------------------------===//
-// tileUsingSCFForAllOp implementation.
-//===----------------------------------------------------------------------===//
-
-FailureOr<scf::SCFTilingResult>
-mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
-                                const scf::SCFTilingOptions &options) {
-  Location loc = op->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-
-  // 1. Get the range of loops that are represented by the operation.
-  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
-  if (loopRanges.empty())
-    return op->emitOpError("expected non-empty loop ranges");
-  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
-  if (llvm::any_of(loopRanges, hasStrideOne))
-    return op->emitOpError("only stride-1 supported atm");
-
-  // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
-  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
-  SmallVector<OpFoldResult> tileSizeVector =
-      options.tileSizeComputationFunction(rewriter, op);
-  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
-
-  // 3. Build the offsets, sizes and steps for the tile and distributed loops.
-  SmallVector<OpFoldResult> lbs, ubs, steps;
-  for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
-    if (isConstantIntValue(tileSize, 0))
-      continue;
-    lbs.push_back(loopRange.offset);
-    ubs.push_back(loopRange.size);
-    steps.push_back(tileSize);
-  }
-
-  // 4. Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
-    return op->emitOpError("failed to get destination tensors");
-
-  // 5. Build the device mapping attribute.
-  std::optional<ArrayAttr> mappingAttr;
-  if (!options.mappingVector.empty()) {
-    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
-  }
-
-  // 6. Create the ForallOp. We don't use the lambda body-builder
-  // version because we require the use of RewriterBase in the body, so we
-  // manually move the insertion point to the body below.
-  auto forallOp =
-      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
-
-  // 7. Get the tile offset and sizes.
-  rewriter.setInsertionPoint(forallOp.getTerminator());
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  ValueRange ivs = forallOp.getInductionVars();
-  {
-    int materializedLoopNum = 0;
-    for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
-      if (isConstantIntValue(tileSize, 0)) {
-        tiledOffsets.push_back(loopRange.offset);
-        tiledSizes.push_back(loopRange.size);
-        continue;
-      }
-      Value iv = ivs[materializedLoopNum++];
-      tiledOffsets.push_back(iv);
-      tiledSizes.push_back(
-          getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
-    }
-  }
-
-  // 8. Tile the operation. Clone the operation to allow fix up of destination
-  // operands.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
-  Operation *clonedOp =
-      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
-  FailureOr<TilingResult> tilingResult =
-      cast<TilingInterface>(clonedOp).getTiledImplementation(
-          rewriter, tiledOffsets, tiledSizes);
-  if (failed(tilingResult))
-    return clonedOp->emitError("failed to tile op: ");
-  rewriter.eraseOp(clonedOp);
-
-  // 9. Parallel insert back into the result tensor.
-  for (auto [index, tiledValue, destBBArg] :
-       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
-    // 9.a. Partial subset information is inserted just before the terminator.
-    rewriter.setInsertionPoint(forallOp.getTerminator());
-
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
-                                        tiledSizes, resultOffsets,
-                                        resultSizes))) {
-      return op->emitOpError("output offsets couldn't be calculated");
-    }
-
-    SmallVector<OpFoldResult> strides(resultSizes.size(),
-                                      rewriter.getIndexAttr(1));
-    // 9.b. Parallel insertions are inserted at the end of the combining
-    // terminator.
-    rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
-    rewriter.create<tensor::ParallelInsertSliceOp>(
-        loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
+    replacements[origVal] = loops.front()->getResult(resultNumber);
   }
 
-  // 10. Return the tiling result.
-  return scf::SCFTilingResult{
-      tilingResult->tiledOps,
-      {forallOp.getOperation()},
-      llvm::map_to_vector(forallOp.getResults(),
-                          [](auto val) -> Value { return val; })};
+  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
+                                   replacements};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a2043c647d49a3..cdd85ddeb93add 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -933,11 +933,11 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
   fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
 
   // Map shared outs.
-  fusedMapping.map(target.getOutputBlockArguments(),
-                   fusedLoop.getOutputBlockArguments().slice(0, numTargetOuts));
+  fusedMapping.map(target.getRegionIterArgs(),
+                   fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
   fusedMapping.map(
-      source.getOutputBlockArguments(),
-      fusedLoop.getOutputBlockArguments().slice(numTargetOuts, numSourceOuts));
+      source.getRegionIterArgs(),
+      fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
 
   // Append everything except the terminator into the fused operation.
   rewriter.setInsertionPointToStart(fusedLoop.getBody());
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index be1316b95688bf..1e0e87b64e8113 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -63,8 +63,9 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
     return op->emitOpError("different number of inits and region iter_args: ")
            << loopLikeOp.getInits().size()
            << " != " << loopLikeOp.getRegionIterArgs().size();
-  if (loopLikeOp.getRegionIterArgs().size() !=
-      loopLikeOp.getYieldedValues().size())
+  if (!loopLikeOp.getYieldedValues().empty() &&
+      loopLikeOp.getRegionIterArgs().size() !=
+          loopLikeOp.getYieldedValues().size())
     return op->emitOpError(
                "different number of region iter_args and yielded values: ")
            << loopLikeOp.getRegionIterArgs().size()
@@ -78,21 +79,22 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
 
   // Verify types of inits/iter_args/yielded values/loop results.
   int64_t i = 0;
-  for (const auto it :
-       llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
-                       loopLikeOp.getYieldedValues())) {
-    if (std::get<0>(it).getType() != std::get<1>(it).getType())
-      return op->emitOpError(std::to_string(i))
-             << "-th init and " << i
-             << "-th region iter_arg have different type: "
-             << std::get<0>(it).getType()
-             << " != " << std::get<1>(it).getType();
-    if (std::get<1>(it).getType() != std::get<2>(it).getType())
-      return op->emitOpError(std::to_string(i))
-             << "-th region iter_arg and " << i
-             << "-th yielded value have different type: "
-             << std::get<1>(it).getType()
-             << " != " << std::get<2>(it).getType();
+  auto yieldedValues = loopLikeOp.getYieldedValues();
+  for (const auto [index, init, regionIterArg] :
+       llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
+    if (init.getType() != regionIterArg.getType())
+      return op->emitOpError(std::to_string(index))
+             << "-th init and " << index
+             << "-th region iter_arg have different type: " << init.getType()
+             << " != " << regionIterArg.getType();
+    if (!yieldedValues.empty()) {
+      if (regionIterArg.getType() != yieldedValues[index].getType())
+        return op->emitOpError(std::to_string(index))
+               << "-th region iter_arg and " << index
+               << "-th yielded value have different type: "
+               << regionIterArg.getType()
+               << " != " << yieldedValues[index].getType();
+    }
     ++i;
   }
   i = 0;
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
index 0e27c6a783e6f1..f0d4b790520e03 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
@@ -46,11 +46,11 @@ func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
 // CHECK:         %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] =
-// CHECK:           %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
 // CHECK:           %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] =
-// CHECK:             %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
-// CHECK:             %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
-// CHECK:             %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-DAG:         %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK-DAG:         %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
+// CHECK-DAG:         %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
+// CHECK-DAG:         %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
 // CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
 // CHECK-SAME:          [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:             %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}}
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 4a940f12662e6c..c42bdbe982c4fa 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -30,8 +30,8 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[H:.*]] = memref.dim %[[ARG2]], %[[C0]]
 //   CHECK-DAG:   %[[W:.*]] = memref.dim %[[ARG2]], %[[C1]]
 //       CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[H]] step %[[C2]]
-//       CHECK:     %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[H]]]
 //       CHECK:     scf.for %[[J:.*]] = %[[C0]] to %[[W]] step %[[C3]]
+//   CHECK-DAG:     %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[H]]]
 //   CHECK-DAG:       %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[W]]]
 //   CHECK-DAG:       %[[T6:.*]] = affine.apply #[[MAP2]](%[[T4]])[%[[KH]]]
 //   CHECK-DAG:       %[[T7:.*]] = affine.apply #[[MAP2]](%[[T5]])[%[[KW]]]
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index e8e63302286400..cdef71ded8b2ca 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -137,9 +137,9 @@ func.func @fold_extract_slice(
   //      CHECK:   %[[E:.*]] = tensor.extract_slice %[[ARG0]][3, 4] [%[[DIM]], 42] [1, 1] : tensor<?x128xf32> to tensor<?x42xf32>
 
   //      CHECK:    scf.for %[[IV0:[0-9a-zA-Z]*]] =
-  //      CHECK:      %[[SIZE0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM]]
   //      CHECK:      scf.for %[[IV1:[0-9a-zA-Z]*]] =
 
+  //      CHECK:      %[[SIZE0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM]]
   // Fold the existing extract slice op into the one created by the tiling.
   //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[E]]
   // CHECK-SAME:                                          %[[IV0]], %[[IV1]]
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir
index 6bec1cbd65be68..1be5bf098c334c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir
@@ -177,8 +177,10 @@ module attributes {transform.with_named_sequence} {
     %pad = transform.get_producer_of_operand %matmul_padded[2]
       : (!transform.any_op) -> !transform.any_op
 
+    transform.apply_licm to %loops_l1#1 : !transform.any_op
+
     transform.structured.hoist_pad.build_packing_loop_nest %pad above %loops_l1#1
        : (!transform.any_op, !transform.any_op) -> !transform.any_op
-       transform.yield
+    transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
index b66db855035bcd..37cb9b2376fb43 100644
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -202,6 +202,8 @@ module attributes {transform.with_named_sequence} {
     %pad = transform.get_producer_of_operand %matmul_padded[2]
       : (!transform.any_op) -> !transform.op<"tensor.pad">
 
+    transform.apply_licm to %loops_l1#1 : !transform.any_op
+
     transform.structured.hoist_pad %pad by 1 loops
        : (!transform.op<"tensor.pad">) -> !transform.any_op
        transform.yield
diff --git a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize.mlir
index 762648050fdfdf..d54cace31efb99 100644
--- a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize.mlir
@@ -52,9 +52,9 @@ func.func @matmul(%A: tensor<1024x512xf32>,
 // CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write {{.*}} } : vector<8x[16]xi1> -> tensor<8x?xf32>
 // CHECK:           scf.yield %inserted_slice : tensor<1024x2000xf32>
 // CHECK:         }
-// CHECK:         scf.yield %7 : tensor<1024x2000xf32>
+// CHECK:         scf.yield {{.*}} : tensor<1024x2000xf32>
 // CHECK:       }
-// CHECK:       scf.yield %5 : tensor<1024x2000xf32>
+// CHECK:       scf.yield {{.*}} : tensor<1024x2000xf32>
 // CHECK-NEXT:    }
 
   %res = linalg.matmul ins(%A, %B: tensor<1024x512xf32>, tensor<512x2000xf32>)
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index bb42f84afc50f9..1afbd3d0504f74 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -127,8 +127,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     else
 //       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, {{.*}}] [7, {{.*}}] [1, 1]
 //       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[5, {{.*}}]
-//       CHECK:     %[[CAST_SWAP_RESULT:.*]] = tensor.cast %[[SWAP_RESULT]] : tensor<?x?xf32> to tensor<15x?xf32>
-//       CHECK:     tensor.insert_slice %[[CAST_SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [15, {{.*}}] [1, 1]
+//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [15, {{.*}}] [1, 1]
 //       CHECK:   return %[[RESULT]]
 
 func.func @static_pad_tensor_0_3(%input_tensor: tensor<7x9xf32>,
@@ -158,15 +157,12 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:   %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C15]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
 //       CHECK:     %[[R2:.*]] = scf.if
 //       CHECK:       %[[GEN:.*]] = tensor.generate
-//       CHECK:       %[[cast_0:.*]] = tensor.cast %[[GEN]] : tensor<14x3xf32> to tensor<?x3xf32>
-//       CHECK:       scf.yield %[[cast_0]] : tensor<?x3xf32>
+//       CHECK:       scf.yield %[[GEN]] : tensor<14x3xf32>
 //       CHECK:     else
 //       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
 //       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[0, 0] high[7, %{{.*}}]
-//       CHECK:       %[[cast_1:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<?x3xf32>
-//       CHECK:       scf.yield %[[cast_1]] : tensor<?x3xf32>
-//       CHECK:     %[[cast:.*]] = tensor.cast %[[R2]] : tensor<?x3xf32> to tensor<14x3xf32>
-//       CHECK:     %[[R3:.*]] = tensor.insert_slice %[[cast]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
+//       CHECK:       scf.yield %[[PAD]] : tensor<14x3xf32>
+//       CHECK:     %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
 //       CHECK:     scf.yield %[[R3]] : tensor<14x15xf32>
 //       CHECK:   return %[[RESULT]] : tensor<14x15xf32>
 
@@ -312,8 +308,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:       %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x8x2xf32>
 // CHECK-DAG:       %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x8x2xf32>
 // CHECK:           %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x8x2xf32>) {
-// CHECK-DAG:         %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
 // CHECK:             %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x8x2xf32>) {
+// CHECK-DAG:           %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
 // CHECK-DAG:           %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
 // CHECK-DAG:           %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])
 // CHECK-DAG:           %[[IN_I_SZ:.*]] = affine.min #[[MAP3]]
@@ -364,11 +360,11 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:       %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x?x?xf32>
 // CHECK-DAG:       %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x?x?xf32>
 // CHECK:           %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x?x?xf32>) {
-// CHECK:             %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
 // CHECK:             %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x?x?xf32>) {
-// CHECK:               %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
-// CHECK:               %[[IN_D0:.*]] = tensor.dim %[[IN]], %[[C0]]
-// CHECK:               %[[IN_D1:.*]] = tensor.dim %[[IN]], %[[C1]]
+// CHECK-DAG:           %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
+// CHECK-DAG:           %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
+// CHECK-DAG:           %[[IN_D0:.*]] = tensor.dim %[[IN]], %[[C0]]
+// CHECK-DAG:           %[[IN_D1:.*]] = tensor.dim %[[IN]], %[[C1]]
 // CHECK:               %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])[%[[TILE_0]]]
 // CHECK:               %[[IN_I_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_I_SZ]], %[[I]])[%[[TILE_0]], %[[IN_D0]]]
 // CHECK:               %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]]
@@ -550,8 +546,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:     %[[DIM_0:.+]] = tensor.dim %[[OUT]], %[[C0]]
 // CHECK-DAG:     %[[DIM_1:.+]] = tensor.dim %[[OUT]], %[[C1]]
 // CHECK:         %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[DIM_0]] step %[[C2]]
-// CHECK-DAG:       %[[OUT_K_SZ:.+]] = affine.min #[[MAP0]](%[[K]])[%[[DIM_0]]]
 // CHECK:           %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[DIM_1]] step %[[C4]]
+// CHECK-DAG:         %[[OUT_K_SZ:.+]] = affine.min #[[MAP0]](%[[K]])[%[[DIM_0]]]
 // CHECK-DAG:         %[[OUT_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])[%[[DIM_1]]]
 // CHECK-DAG:         %[[IN_K:.+]] = affine.apply #[[MAP2]](%[[K]])
 // CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir
new file mode 100644
index 00000000000000..0bd2546e082b5a
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %gemm : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_using_forall %matmul [10, 20]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @gemm_fill_fusion(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty
+//      CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME:       shared_outs(%[[ITERARG0:.+]] = %[[INIT]])
+//  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+//  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV0]], %[[IV1]]]
+//      CHECK:     %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:         outs(%[[FILL_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[ITERARG0]][%[[IV0]], %[[IV1]]]
+//      CHECK:     }
+
+// -----
+
+func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %generic = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  return %generic : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_using_forall %generic [10, 20]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @gemm_generic_fusion(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty
+//      CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME:       shared_outs(%[[ITERARG0:.+]] = %[[INIT]])
+//  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+//  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+//      CHECK:     %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:         outs(%[[FILL_TILE]] :
+//  CHECK-DAG:     %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
+//  CHECK-DAG:     %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV0]], %[[IV1]]]
+//      CHECK:     %[[GENERIC_TILE:.+]] = linalg.generic
+// CHECK-SAME:         ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
+// CHECK-SAME:         outs(%[[OUTS_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GENERIC_TILE]] into %[[ITERARG0]][%[[IV0]], %[[IV1]]]
+//      CHECK:     }
+
+// -----
+
+func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<30xf32>
+  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
+  %2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %8 = arith.maximumf %arg2, %arg1 : f32
+      linalg.yield %8 : f32
+    } -> tensor<30xf32>
+  %3 = tensor.empty() : tensor<30x3xf32>
+  %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
+  %5:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
+      %8 = arith.subf %arg1, %arg2 : f32
+      %9 = math.exp %8 : f32
+      %10 = arith.addf %arg3, %9 : f32
+      linalg.yield %10, %9 : f32, f32
+    } -> (tensor<30xf32>, tensor<30x3xf32>)
+  %6 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+      %8 = arith.divf %arg1, %arg2 : f32
+      linalg.yield %8 : f32
+    } -> tensor<30x3xf32>
+  return %6 : tensor<30x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %generic1, %generic2, %generic3 = transform.split_handle %generics
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_using_forall %generic3 [10]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//       CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>)
+//   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<30xf32>
+//   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32>
+//       CHECK:   %[[RESULT:[a-zA-Z0-9]+]] = scf.forall (%[[IV:[a-zA-Z0-9]+]])
+//  CHECK-SAME:       shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]])
+//   CHECK-DAG:     %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0]
+//   CHECK-DAG:     %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]]
+//       CHECK:     %[[FILL0:.+]] = linalg.fill
+//  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[ARG0_SLICE]] :
+//  CHECK-SAME:         outs(%[[FILL0]] :
+//       CHECK:     %[[FILL1:.+]] = linalg.fill
+//  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
+//       CHECK:     %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+//       CHECK:     %[[GENERIC1:.+]]:2 = linalg.generic
+//  CHECK-SAME:         ins(%[[ARG0_SLICE]], %[[GENERIC0]] :
+//  CHECK-SAME:         outs(%[[FILL1]], %[[INIT1_SLICE]] :
+//       CHECK:     %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//       CHECK:     %[[GENERIC2:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 :
+//  CHECK-SAME:         outs(%[[ITERARG0_SLICE]] :
+//       CHECK:     scf.forall.in_parallel {
+//       CHECK:       tensor.parallel_insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
+//       CHECK:     }
+//       CHECK:   return %[[RESULT]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 3d353c068a9f95..7356c11e85ac0c 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -57,3 +57,4 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
 //      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
 //      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]]
+//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir
new file mode 100644
index 00000000000000..8fc8f3245be159
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s
+
+func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
+    %init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
+  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm0 = linalg.matmul
+      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm1 = linalg.matmul
+      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %mm1, %mm2 = transform.split_handle %matmuls
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_and_yield %mm2 [10] use_forall true
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @gemm_gemm_fusion_yield_both(
+// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[RESULT:.+]]:2 = scf.forall (%[[IV:[a-zA-Z0-9]+]]) =
+// CHECK-SAME:       shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
+//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT0_TILE]] :
+//      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:         outs(%[[FILL0_TILE]] :
+//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT1_TILE]] :
+//      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
+// CHECK-SAME:         outs(%[[FILL1_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:       tensor.parallel_insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
diff --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
index 05b7afdf0d1ca4..ba56206f03d767 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
@@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     else
 //       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
 //       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[{{.*}}, {{.*}}]
-//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][%[[C0]], {{.*}}] [%[[DIM0]], {{.*}}] [1, 1]
+//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [%[[DIM0]], {{.*}}] [1, 1]
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -143,7 +143,6 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-SAME:     %[[IN:.*]]: tensor<7x9xf32>
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
 //   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
 //       CHECK:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
 //       CHECK:     %[[SWAP_RESULT:.*]] = scf.if
@@ -151,7 +150,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     else
 //       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, {{.*}}] [7, {{.*}}] [1, 1]
 //       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[5, {{.*}}]
-//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][%[[C0]], {{.*}}] [%[[C15]], {{.*}}] [1, 1]
+//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [15, {{.*}}] [1, 1]
 //       CHECK:   return %[[RESULT]]
 
 /// Rest of the tests only check that they dont fail.
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 444232e9e1e2e1..607836faafb71d 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -30,10 +30,10 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
-//  CHECK-DAG:     %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
 //  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//  CHECK-DAG:       %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
 //      CHECK:       %[[TS_X:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
 // CHECK-SAME:           [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
@@ -82,13 +82,13 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
 //      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
-//  CHECK-DAG:     %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
 //  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
-//  CHECK-DAG:       %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
 //  CHECK-DAG:       %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
-//      CHECK:         %[[TS_K:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[K]]]
+//  CHECK-DAG:         %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
+//  CHECK-DAG:         %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
+//  CHECK-DAG:         %[[TS_K:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[K]]]
 //  CHECK-DAG:         %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
 // CHECK-SAME:             [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
 //  CHECK-DAG:         %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
@@ -137,11 +137,11 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty()
 //       CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
 //  CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
-//   CHECK-DAG:     %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
 //   CHECK-DAG:     %[[C300:.+]] = arith.constant 300 : index
 //   CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //       CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
 //  CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
+//   CHECK-DAG:       %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
 //   CHECK-DAG:       %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
 //  CHECK-SAME:           [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
 //   CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
@@ -203,14 +203,14 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
 //      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[INIT]])
-//  CHECK-DAG:     %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
 //  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
-//  CHECK-DAG:       %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]]
 //  CHECK-DAG:       %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
 // CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
+//  CHECK-DAG:         %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
+//  CHECK-DAG:         %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]]
 //  CHECK-DAG:         %[[TS_C:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[C]]]
 //  CHECK-DAG:         %[[TS_H:.+]] = affine.apply #[[$MAP3]](%[[TS_P]])[%[[R]]]
 //  CHECK-DAG:         %[[TS_W:.+]] = affine.apply #[[$MAP4]](%[[TS_Q]])[%[[S]]]
@@ -302,14 +302,14 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
-//  CHECK-DAG:     %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]
 //  CHECK-DAG:     %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:     %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
-//  CHECK-DAG:       %[[TS_K:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[K]]]
 //  CHECK-DAG:       %[[C10:.+]] = arith.constant 10 : index
 //      CHECK:       %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
 // CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
+//  CHECK-DAG:         %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]
+//  CHECK-DAG:         %[[TS_K:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[K]]]
 //  CHECK-DAG:         %[[TS_M:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[M]]]
 //  CHECK-DAG:         %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
 // CHECK-SAME:             [%[[IV2]], %[[IV1]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
index db0c1327e2fe02..c5aff744b57ee6 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -2,7 +2,7 @@
 
 func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.matmul 
+  %0 = linalg.matmul
     ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
@@ -12,7 +12,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.tile_using_forall %matmul [10, 20] mapping [#gpu.block<y>, #gpu.block<x>]
+    %a, %b = transform.test.tile_using_forall %matmul [10, 20] mapping = [#gpu.block<y>, #gpu.block<x>]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
@@ -49,6 +49,48 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+    %arg2 : memref<?x?xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.tile_using_forall %matmul [10, 20]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+//      CHECK-LABEL: func.func @simple_matmul_memref(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//      CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = (0, 0) to (%[[M]], %[[N]]) step (10, 20) {
+//  CHECK-DAG:     %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
+//  CHECK-DAG:     %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
+//  CHECK-DAG:     %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
+// CHECK-SAME:         [%[[IV0]], 0] [%[[TS_M]], %[[K]]] [1, 1]
+//  CHECK-DAG:     %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME:         [0, %[[IV1]]] [%[[K]], %[[TS_N]]] [1, 1]
+//  CHECK-DAG:     %[[OUT_TILE:.+]] = memref.subview %[[ARG2]]
+// CHECK-SAME:         [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:     linalg.matmul
+// CHECK-SAME:             ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:             outs(%[[OUT_TILE]] :
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -203,3 +245,107 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:   %[[INDEX1:.+]] = linalg.index 1
 //       CHECK:   %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
 //       CHECK:   arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
+
+// -----
+
+func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.tile_using_forall %matmul [10, 20] interchange = [1, 0] mapping = [#gpu.block<y>, #gpu.block<x>]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+//      CHECK-LABEL: func.func @interchange_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]
+// CHECK-SAME:       (0, 0) to (%[[N]], %[[M]]) step (20, 10)
+// CHECK-SAME:       shared_outs(%[[INIT0:.+]] = %[[ARG2]])
+//  CHECK-DAG:     %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]
+//  CHECK-DAG:     %[[TS_M:.+]] = affine.min #[[$MAP2]](%[[IV1]])[%[[M]]]
+//  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:         [%[[IV1]], 0] [%[[TS_M]], %[[K]]] [1, 1]
+//  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:         [0, %[[IV0]]] [%[[K]], %[[TS_N]]] [1, 1]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT0]]
+// CHECK-SAME:         [%[[IV1]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT0]]
+// CHECK-SAME:           [%[[IV1]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:     } {mapping = [#gpu.block<y>, #gpu.block<x>]}
+//      CHECK:   return %[[OUTER]]
+
+// -----
+
+func.func @check_scalar_operation(%arg0 : tensor<f32>) -> tensor<f32> {
+  %init = tensor.empty() : tensor<f32>
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+      iterator_types = []}
+      ins(%arg0 : tensor<f32>) outs(%init : tensor<f32>){
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a = transform.test.tile_using_forall %generic []
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @check_scalar_operation
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.generic
+
+// -----
+
+func.func @check_scalar_memref_operation(%arg0 : memref<f32>, %arg1 : memref<f32>){
+  linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+      iterator_types = []}
+      ins(%arg0 : memref<f32>) outs(%arg1 : memref<f32>){
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a = transform.test.tile_using_forall %generic []
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @check_scalar_memref_operation
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.generic
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index cc450f45649516..232da2726761b8 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -58,8 +58,8 @@ static LogicalResult
 applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
                       Range &&payloadOps, unsigned numLoops,
                       ArrayRef<OpFoldResult> tileSizes,
-                      ArrayRef<int64_t> interchange,
-                      transform::TransformResults &transformResults) {
+                      ArrayRef<int64_t> interchange, bool useForall,
+                      TransformResults &transformResults) {
   SmallVector<Operation *> tiledOps;
   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
 
@@ -82,6 +82,9 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
 
     scf::SCFTilingOptions tilingOptions;
     tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+    if (useForall) {
+      tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+    }
 
     scf::SCFTileAndFuseOptions tileAndFuseOptions;
     tileAndFuseOptions.setTilingOptions(tilingOptions);
@@ -97,8 +100,8 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
 
     rewriter.setInsertionPoint(target);
     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
-        scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
-            rewriter, tilingInterfaceOp, tileAndFuseOptions);
+        scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
+                                                  tileAndFuseOptions);
     if (failed(tiledResults))
       return failure();
 
@@ -109,12 +112,11 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
       for (OpResult res : toReplace->getResults())
         if (auto replacement = tiledResults->replacements.lookup(res)) {
           Operation *replacementOp = replacement.getDefiningOp();
-          rewriter.replaceUsesWithIf(
-              res, replacement, [&](mlir::OpOperand &use) {
-                Operation *user = use.getOwner();
-                return dominanceInfo.properlyDominates(replacementOp, user) &&
-                       user->getParentOp() == replacementOp->getParentOp();
-              });
+          rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
+            Operation *user = use.getOwner();
+            return dominanceInfo.properlyDominates(replacementOp, user) &&
+                   user->getParentOp() == replacementOp->getParentOp();
+          });
         }
 
       if (toReplace->use_empty()) {
@@ -138,10 +140,10 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
   return success();
 }
 
-DiagnosedSilenceableFailure transform::TestFuseAndYieldOp::apply(
-    transform::TransformRewriter &rewriter,
-    mlir::transform::TransformResults &transformResults,
-    mlir::transform::TransformState &state) {
+DiagnosedSilenceableFailure
+transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
   SmallVector<int64_t> tileSizes =
       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
   SmallVector<int64_t> tileInterchange =
@@ -153,7 +155,7 @@ DiagnosedSilenceableFailure transform::TestFuseAndYieldOp::apply(
   LogicalResult result = applyTileAndFuseToAll(
       rewriter, getOperation(), state.getPayloadOps(getTarget()),
       tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
-      tileInterchange, transformResults);
+      tileInterchange, getUseForall(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
 }
@@ -169,7 +171,7 @@ static LogicalResult
 applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
                Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
                ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
-               transform::TransformResults &transformResults) {
+               TransformResults &transformResults) {
   SmallVector<Operation *> tiledOps;
   SmallVector<Operation *> loopOps;
 
@@ -186,10 +188,11 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
           });
       tilingOptions.setMapping(mappingAttrs);
     }
+    tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
 
     rewriter.setInsertionPoint(target);
     FailureOr<scf::SCFTilingResult> tiledResults =
-        scf::tileUsingSCFForallOp(rewriter, tilingInterfaceOp, tilingOptions);
+        scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
     if (failed(tiledResults))
       return failure();
 
@@ -209,10 +212,10 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
   return success();
 }
 
-DiagnosedSilenceableFailure transform::TestTileUsingForallOp::apply(
-    transform::TransformRewriter &rewriter,
-    mlir::transform::TransformResults &transformResults,
-    mlir::transform::TransformState &state) {
+DiagnosedSilenceableFailure
+transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
   SmallVector<int64_t> tileSizes =
       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
   SmallVector<int64_t> interchange =
@@ -235,6 +238,96 @@ void transform::TestTileUsingForallOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// TestFuseUsingForallOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+template <typename Range>
+static LogicalResult applyTilingToAll(
+    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
+    unsigned numLoops, TransformResults &transformResults,
+    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
+        applyFn) {
+  SmallVector<Operation *> tiledLinalgOps;
+  SmallVector<SmallVector<Operation *>> loopOps(1);
+
+  for (Operation *target : payloadOps) {
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp)
+      return transformOp->emitError("only TilingInterface ops are supported");
+
+    rewriter.setInsertionPoint(target);
+    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+        applyFn(tilingInterfaceOp);
+    if (failed(tiledResults))
+      return failure();
+
+    // Perform the replacement of tiled and fused values.
+    SmallVector<Operation *> opsToReplace{target};
+    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
+    for (Operation *toReplace : opsToReplace) {
+      for (OpResult res : toReplace->getResults())
+        if (auto replacement = tiledResults->replacements.lookup(res))
+          rewriter.replaceAllUsesWith(res, replacement);
+      if (toReplace->use_empty())
+        rewriter.eraseOp(toReplace);
+    }
+
+    // Report back the relevant handles to the transform op.
+    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
+    assert(tiledResults->loops.size() == 1 &&
+           cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
+           "Mismatched number of loops, tile and fuse transform should have "
+           "failed");
+    loopOps[0].push_back({tiledResults->loops[0]});
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
+  if (!loopOps.empty())
+    transformResults.set(transformOp->getOpResult(1), loopOps[0]);
+
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+  SmallVector<int64_t> tileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+  SmallVector<int64_t> tileInterchange =
+      extractFromIntegerArrayAttr<int64_t>(getInterchange());
+
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.interchangeVector = tileInterchange;
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
+  tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+  scf::SCFTileAndFuseOptions tileAndFuseOptions;
+  tileAndFuseOptions.tilingOptions = tilingOptions;
+  LogicalResult result = applyTilingToAll(
+      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
+      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+      [&](TilingInterface tilingInterfaceOp)
+          -> FailureOr<scf::SCFTileAndFuseResult> {
+        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
+                                                    tileAndFuseOptions);
+      });
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseUsingForallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getRootOp(), effects);
+  producesHandle(getTiledOps(), effects);
+  producesHandle(getLoops(), effects);
+  modifiesPayload(effects);
+}
+
 #define GET_OP_CLASSES
 #include "TestTilingInterfaceTransformOps.cpp.inc"
 
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 6e9354198896ab..9760eb70fafb99 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -37,13 +37,15 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
   let arguments =
     (ins TransformHandleTypeInterface:$target,
         DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
-        DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
+        DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
+        DefaultValuedAttr<BoolAttr, "{false}">:$use_forall);
   let results = (outs TransformHandleTypeInterface:$transfomed,
       Variadic<TransformHandleTypeInterface>:$loops);
 
   let assemblyFormat = [{
     $target ($tile_sizes^)? (`interchange` $tile_interchange^)?
-    attr-dict `:` functional-type(operands, results)
+    (`use_forall` $use_forall^)? attr-dict 
+    `:` functional-type(operands, results)
   }];
 }
 
@@ -71,11 +73,33 @@ def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
                       Variadic<TransformHandleTypeInterface>:$loops);
 
   let assemblyFormat = [{
-    $target ($tile_sizes^)? (`interchange` $interchange^)?
-    (`mapping` $mapping^)?
+    $target ($tile_sizes^)? (`interchange` `=` $interchange^)?
+    (`mapping` `=` $mapping^)?
     attr-dict `:` functional-type(operands, results)
   }];
 }
 
+def TestFuseUsingForallOp : Op<Transform_Dialect, "test.fuse_using_forall",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Test operation to tile the operations pointed to by the `root_op` handle and
+    fuse their producers greedily using the options provided as attributes.
+    This operation uses scf.forall for the loop construct.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$root_op,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+                   DefaultValuedOptionalAttr<I64ArrayAttr, "{}">:$interchange,
+                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
+  let results = (outs TransformHandleTypeInterface:$tiled_ops,
+                      Variadic<TransformHandleTypeInterface>:$loops);
+
+  let assemblyFormat = [{
+    $root_op ($tile_sizes^)? (`interchange` $interchange^)?
+    (`mapping` `=` $mapping^)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
 
 #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS

>From 0b91c982c3d0dcd0c0d32e14997939543e00806a Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 17 Jan 2024 23:44:52 -0800
Subject: [PATCH 02/12] Make `getYieldedValuesMutable` return an
 `std::optional<llvm::MutableArrayRef>`.

---
 .../mlir/Interfaces/LoopLikeInterface.td      | 27 ++++++++++++-------
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  2 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  4 +--
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 20afc35571fbf2..83746bf04ddf3d 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -164,13 +164,16 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     InterfaceMethod<[{
         Return the mutable operand range of values that are yielded to the next
         iteration by the loop terminator.
+
+        For loop operations that don't yield a value, this should return
+        std::nullopt.
       }],
-      /*retTy=*/"::llvm::MutableArrayRef<::mlir::OpOperand>",
+      /*retTy=*/"std::optional<::llvm::MutableArrayRef<::mlir::OpOperand>>",
       /*methodName=*/"getYieldedValuesMutable",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return {};
+        return std::nullopt;
       }]
     >,
     InterfaceMethod<[{
@@ -257,16 +260,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
           });
     }
 
-    /// Return the values that are yielded to the next iteration.
+    /// Return the values that are yielded to the next iteration. If
+    /// the loop doesn't yield any values, return `{}`.
     ::mlir::ValueRange getYieldedValues() {
       auto mutableValues = $_op.getYieldedValuesMutable();
-      if (mutableValues.empty())
+      if (!mutableValues || mutableValues->empty())
         return {};
-      Operation *yieldOp = mutableValues.begin()->getOwner();
-      unsigned firstOperandIndex = mutableValues.begin()->getOperandNumber();
+      Operation *yieldOp = mutableValues->begin()->getOwner();
+      unsigned firstOperandIndex = mutableValues->begin()->getOperandNumber();
       return OperandRange(
           yieldOp->operand_begin() + firstOperandIndex,
-          yieldOp->operand_begin() + firstOperandIndex + mutableValues.size());
+          yieldOp->operand_begin() + firstOperandIndex + mutableValues->size());
     }
 
     /// Return the "init" operands that are used as initialization values for
@@ -331,14 +335,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
 
     /// Return the yielded value that corresponds to the given region iter_arg.
     /// Return "nullptr" if the given block argument is not a region iter_arg
-    /// of this loop op.
+    /// of this loop op or if there is no yield corresponding to this `bbArg`.
     OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
       auto iterArgs = $_op.getRegionIterArgs();
       auto it = llvm::find(iterArgs, bbArg);
       if (it == iterArgs.end())
         return {};
-      return
-          &$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
+      std::optional<llvm::MutableArrayRef<::mlir::OpOperand>> yieldValues =
+        $_op.getYieldedValuesMutable();
+      if (!yieldValues)
+        return {};
+      return &yieldValues.value()[std::distance(iterArgs.begin(), it)];
     }
 
     /// Return the loop result that corresponds to the given init operand.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d5be2e906989fa..84c5accb1fbfe5 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2127,7 +2127,7 @@ unsigned AffineForOp::getNumIterOperands() {
   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
 }
 
-MutableArrayRef<OpOperand> AffineForOp::getYieldedValuesMutable() {
+std::optional<MutableArrayRef<OpOperand>> AffineForOp::getYieldedValuesMutable() {
   return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 31101861ad6f45..0ffbd8e2a3f17f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1212,7 +1212,7 @@ std::optional<APInt> ForOp::getConstantStep() {
   return {};
 }
 
-MutableArrayRef<OpOperand> ForOp::getYieldedValuesMutable() {
+std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
   return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
 }
 
@@ -3222,7 +3222,7 @@ YieldOp WhileOp::getYieldOp() {
   return cast<YieldOp>(getAfterBody()->getTerminator());
 }
 
-MutableArrayRef<OpOperand> WhileOp::getYieldedValuesMutable() {
+std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
   return getYieldOp().getResultsMutable();
 }
 

>From bc202a555b4b7bd100b3b466f0c39ccae821e609 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 17 Jan 2024 23:45:45 -0800
Subject: [PATCH 03/12] Add description for `yieldTiledValuesAndReplace`.

---
 .../mlir/Interfaces/LoopLikeInterface.td      | 36 ++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 83746bf04ddf3d..2f4dbf572ce35e 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -222,7 +222,41 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        TODO
+        Append the specified additional "init" operands: replace this loop with
+        a new loop that has the additional init operands. The loop body of
+        this loop is moved over to the new loop.
+
+        This method is similar to `replaceWithAdditionalYields` but instead of
+        returning the value that is actually yielded, this returns the tiles of
+        the values that are yielded. This allows for unified handling of opreations
+        like `scf.forall` which dont yield a value from the loop, but instead
+        the terminator specifies where to insert the tile yielded by the body of
+        the loop. For example,
+        
+        ```mlir
+        %0 = scf.forall ... shared_outs(%arg0 = %arg1) {
+          ...
+          %tiled_value
+          scf.forall.in_parallel {
+            tensor.parallel_insert_slice %tiled_value into %arg0[%o1, %o2]...
+          }
+        }
+        ```
+
+        For an `scf.for` the same computation would be represented as
+        ```mlir
+        %0 = scf.for ... iter_args(%arg0 = %arg1) {
+          ...
+          %tiled_value
+          %insert = tensor.insert_slice %tiled_value into %arg0[%o1, %o2]...
+          scf.yield %insert
+        }
+        ```
+
+        So for the caller, the tiled value (`%tiled_values`) and the offsets
+        `(%o1, %o2)` and sizes (not shown) are generated the same way, but
+        the implementation method for the different loop constructs handles
+        the difference in representation.
       }],
       /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
       /*methodName=*/"yieldTiledValuesAndReplace",

>From d5790138d6343afe136fd9615df8e300fc63bfa5 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 17 Jan 2024 23:46:28 -0800
Subject: [PATCH 04/12] Attempt to fix Windows build error.

---
 .../TilingInterface/TestTilingInterfaceTransformOps.cpp         | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 232da2726761b8..b6a0ad84eee011 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -281,7 +281,7 @@ static LogicalResult applyTilingToAll(
            cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
            "Mismatched number of loops, tile and fuse transform should have "
            "failed");
-    loopOps[0].push_back({tiledResults->loops[0]});
+    loopOps[0] = {tiledResults->loops[0]};
   }
 
   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);

>From 4f62eda6303210513502f49492e2294d40e857e4 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 17 Jan 2024 23:49:50 -0800
Subject: [PATCH 05/12] Address comments (round 1).

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0ffbd8e2a3f17f..29a09a7050857a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -596,7 +596,7 @@ FailureOr<LoopLikeOpInterface> ForOp::yieldTiledValuesAndReplace(
 
   auto inits = llvm::to_vector(getInitArgs());
   inits.append(newInitOperands.begin(), newInitOperands.end());
-  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+  auto newLoop = rewriter.create<ForOp>(
       getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
       [](OpBuilder &, Location, Value, ValueRange) {});
 

>From ddf2883378c6ee935587383dca7721d965cb3ae5 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 17 Jan 2024 23:57:27 -0800
Subject: [PATCH 06/12] Fix flang errors due to change in
 `getYieldedValuesMutable` signature.

---
 flang/lib/Optimizer/Dialect/FIROps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 6d62e470706e53..f826f2566b897a 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1935,7 +1935,7 @@ mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) {
   return {};
 }
 
-llvm::MutableArrayRef<mlir::OpOperand>
+std::optional<llvm::MutableArrayRef<mlir::OpOperand>>
 fir::IterWhileOp::getYieldedValuesMutable() {
   auto *term = getRegion().front().getTerminator();
   return getFinalValue() ? term->getOpOperands().drop_front()
@@ -2247,7 +2247,7 @@ mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) {
   return {};
 }
 
-llvm::MutableArrayRef<mlir::OpOperand>
+std::optional<llvm::MutableArrayRef<mlir::OpOperand>>
 fir::DoLoopOp::getYieldedValuesMutable() {
   auto *term = getRegion().front().getTerminator();
   return getFinalValue() ? term->getOpOperands().drop_front()

>From 63cd5033f29ad6a8341412a4b358e8bb8f8b9677 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 17 Jan 2024 23:59:17 -0800
Subject: [PATCH 07/12] Fix code-formatting error.

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 84c5accb1fbfe5..cfe71d1d836bb6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2127,7 +2127,8 @@ unsigned AffineForOp::getNumIterOperands() {
   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
 }
 
-std::optional<MutableArrayRef<OpOperand>> AffineForOp::getYieldedValuesMutable() {
+std::optional<MutableArrayRef<OpOperand>>
+AffineForOp::getYieldedValuesMutable() {
   return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
 }
 

>From 38889dc0a205351a5c59b3610de4ad348556d6fd Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 18 Jan 2024 00:05:32 -0800
Subject: [PATCH 08/12] Fix documentation errors.

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.td | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 2f4dbf572ce35e..82ba99b59948ef 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -227,16 +227,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         this loop is moved over to the new loop.
 
         This method is similar to `replaceWithAdditionalYields` but instead of
-        returning the value that is actually yielded, this returns the tiles of
-        the values that are yielded. This allows for unified handling of opreations
-        like `scf.forall` which dont yield a value from the loop, but instead
-        the terminator specifies where to insert the tile yielded by the body of
+        yielding a value from within the loop, it allows each loop construct
+        implementing this method to handle the result of each iteration
+        appropriately. This allows for unified handling of operations
+        like `scf.forall` which don't yield a value from the loop, but instead
+        the terminator specifies where to insert the tile computed by the body of
         the loop. For example,
         
         ```mlir
         %0 = scf.forall ... shared_outs(%arg0 = %arg1) {
           ...
-          %tiled_value
+          %tiled_value = ...
           scf.forall.in_parallel {
             tensor.parallel_insert_slice %tiled_value into %arg0[%o1, %o2]...
           }
@@ -247,13 +248,13 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         ```mlir
         %0 = scf.for ... iter_args(%arg0 = %arg1) {
           ...
-          %tiled_value
+          %tiled_value = ...
           %insert = tensor.insert_slice %tiled_value into %arg0[%o1, %o2]...
           scf.yield %insert
         }
         ```
 
-        So for the caller, the tiled value (`%tiled_values`) and the offsets
+        So for the caller, the tiled value (`%tiled_value`) and the offsets
         `(%o1, %o2)` and sizes (not shown) are generated the same way, but
         the implementation method for the different loop constructs handles
         the difference in representation.

>From ed64536f8021ecccfd875b71f3918314cc644a2b Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 18 Jan 2024 00:12:22 -0800
Subject: [PATCH 09/12] Use llvm::function_ref.

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index c62476f9b62256..ca7c86cff81da9 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -22,7 +22,7 @@ class RewriterBase;
 /// `replaceWithAdditionalYields`. `newBbArgs` are the newly added region
 /// iter_args. This function should return as many values as there are block
 /// arguments in `newBbArgs`.
-using NewYieldValuesFn = std::function<SmallVector<Value>(
+using NewYieldValuesFn = llvm::function_ref<SmallVector<Value>(
     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
 
 /// A function that allows returning additional yielded values during
@@ -41,7 +41,7 @@ using NewYieldValuesFn = std::function<SmallVector<Value>(
 /// - `resultStrides` is of the same size as `tiledValues` and represents
 ///   the strides to use when inserting corresponding element from `tiledValues`
 ///   into the element from `newBbArgs`.
-using YieldTiledValuesFn = std::function<LogicalResult(
+using YieldTiledValuesFn = llvm::function_ref<LogicalResult(
     RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
     SmallVector<Value> &tiledValues,
     SmallVector<SmallVector<OpFoldResult>> &resultOffsets,

>From 3f51bc2066c502893b2b2dc70c0e5925397b1372 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 18 Jan 2024 15:24:19 -0800
Subject: [PATCH 10/12] Address comments (round 2)

- Remove passing as `const &` for the `llvm::function_ref`.
- Fix ordering of interface methods.
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td        |  9 +++++----
 mlir/include/mlir/Interfaces/LoopLikeInterface.td |  2 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp                   | 14 ++++++++------
 3 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 08caaa0b880b45..b3d085bfff1af9 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -135,10 +135,11 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-       ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
-        "getSingleLowerBound", "getSingleStep", "getSingleUpperBound",
-        "getYieldedValuesMutable", "getLoopResults", "promoteIfSingleIteration",
-        "replaceWithAdditionalYields", "yieldTiledValuesAndReplace"]>,
+       ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+        "getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
+        "getSingleUpperBound", "getYieldedValuesMutable",
+        "promoteIfSingleIteration", "replaceWithAdditionalYields",
+        "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 82ba99b59948ef..f2afd0ec661de4 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -263,7 +263,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*methodName=*/"yieldTiledValuesAndReplace",
       /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
                     "::mlir::ValueRange":$newInitOperands,
-                    "const ::mlir::YieldTiledValuesFn &":$yieldTiledValuesFn),
+                    "::mlir::YieldTiledValuesFn":$yieldTiledValuesFn),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return ::mlir::failure();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 29a09a7050857a..9c1515319747a6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -588,9 +588,10 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
   return cast<LoopLikeOpInterface>(newLoop.getOperation());
 }
 
-FailureOr<LoopLikeOpInterface> ForOp::yieldTiledValuesAndReplace(
-    RewriterBase &rewriter, ValueRange newInitOperands,
-    const YieldTiledValuesFn &yieldTiledValuesFn) {
+FailureOr<LoopLikeOpInterface>
+ForOp::yieldTiledValuesAndReplace(RewriterBase &rewriter,
+                                  ValueRange newInitOperands,
+                                  YieldTiledValuesFn yieldTiledValuesFn) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(getOperation());
 
@@ -691,9 +692,10 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
   return getOutputsMutable();
 }
 
-FailureOr<LoopLikeOpInterface> ForallOp::yieldTiledValuesAndReplace(
-    RewriterBase &rewriter, ValueRange newInitOperands,
-    const YieldTiledValuesFn &yieldTiledValuesFn) {
+FailureOr<LoopLikeOpInterface>
+ForallOp::yieldTiledValuesAndReplace(RewriterBase &rewriter,
+                                     ValueRange newInitOperands,
+                                     YieldTiledValuesFn yieldTiledValuesFn) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(getOperation());
   auto inits = llvm::to_vector(getOutputs());

>From 20271d9fd2ee05b2375d76f19eacd5c37c57d4f6 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 18 Jan 2024 17:23:06 -0800
Subject: [PATCH 11/12] Drop `yieldTiledValuesAndReplace` as an interface
 method.

---
 .../mlir/Interfaces/LoopLikeInterface.td      |  50 +-----
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 106 -----------
 .../SCF/Transforms/TileUsingInterface.cpp     | 170 +++++++++++++++++-
 3 files changed, 169 insertions(+), 157 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index f2afd0ec661de4..e2ac85a3f7725d 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -220,55 +220,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::mlir::failure();
       }]
-    >,
-    InterfaceMethod<[{
-        Append the specified additional "init" operands: replace this loop with
-        a new loop that has the additional init operands. The loop body of
-        this loop is moved over to the new loop.
-
-        This method is similar to `replaceWithAdditionalYields` but instead of
-        yielding a value from within the loop, it allows each loop construct
-        implementing this method to handle the result of each iteration
-        appropriately. This allows for unified handling of operations
-        like `scf.forall` which don't yield a value from the loop, but instead
-        the terminator specifies where to insert the tile computed by the body of
-        the loop. For example,
-        
-        ```mlir
-        %0 = scf.forall ... shared_outs(%arg0 = %arg1) {
-          ...
-          %tiled_value = ...
-          scf.forall.in_parallel {
-            tensor.parallel_insert_slice %tiled_value into %arg0[%o1, %o2]...
-          }
-        }
-        ```
-
-        For an `scf.for` the same computation would be represented as
-        ```mlir
-        %0 = scf.for ... iter_args(%arg0 = %arg1) {
-          ...
-          %tiled_value = ...
-          %insert = tensor.insert_slice %tiled_value into %arg0[%o1, %o2]...
-          scf.yield %insert
-        }
-        ```
-
-        So for the caller, the tiled value (`%tiled_value`) and the offsets
-        `(%o1, %o2)` and sizes (not shown) are generated the same way, but
-        the implementation method for the different loop constructs handles
-        the difference in representation.
-      }],
-      /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
-      /*methodName=*/"yieldTiledValuesAndReplace",
-      /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                    "::mlir::ValueRange":$newInitOperands,
-                    "::mlir::YieldTiledValuesFn":$yieldTiledValuesFn),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return ::mlir::failure();
-      }]
-    >,
+    >
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9c1515319747a6..f0f60294ea7bf5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -588,64 +588,6 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
   return cast<LoopLikeOpInterface>(newLoop.getOperation());
 }
 
-FailureOr<LoopLikeOpInterface>
-ForOp::yieldTiledValuesAndReplace(RewriterBase &rewriter,
-                                  ValueRange newInitOperands,
-                                  YieldTiledValuesFn yieldTiledValuesFn) {
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(getOperation());
-
-  auto inits = llvm::to_vector(getInitArgs());
-  inits.append(newInitOperands.begin(), newInitOperands.end());
-  auto newLoop = rewriter.create<ForOp>(
-      getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
-      [](OpBuilder &, Location, Value, ValueRange) {});
-
-  // Move the loop body to the new op.
-  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
-                       newLoop.getBody()->getArguments().take_front(
-                           getBody()->getNumArguments()));
-
-  auto yieldOp = cast<scf::YieldOp>(newLoop.getBody()->getTerminator());
-  rewriter.setInsertionPoint(yieldOp);
-
-  SmallVector<Value> tiledValues;
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
-  ValueRange newRegionIterArgs =
-      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
-  if (failed(yieldTiledValuesFn(rewriter, getLoc(), newLoop.getInductionVar(),
-                                newRegionIterArgs, tiledValues, resultOffsets,
-                                resultSizes))) {
-    return rewriter.notifyMatchFailure(getOperation(),
-                                       "failed to get tiled values");
-  }
-
-  if (tiledValues.size() != resultOffsets.size() ||
-      tiledValues.size() != resultSizes.size()) {
-    return rewriter.notifyMatchFailure(
-        getOperation(),
-        "expected number of tiled values returned, the number of offset "
-        "vectors and number of size vectors to be the same");
-  }
-
-  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
-  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
-                       resultSizes)) {
-    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
-                                           rewriter.getIndexAttr(1));
-    Value insert = rewriter.create<tensor::InsertSliceOp>(
-        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
-        resultStride);
-    newYieldValues.push_back(insert);
-  }
-
-  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
-  rewriter.replaceOp(getOperation(),
-                     newLoop->getResults().take_front(getNumResults()));
-  return cast<LoopLikeOpInterface>(newLoop.getOperation());
-}
-
 ForOp mlir::scf::getForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
   if (!ivArg)
@@ -692,54 +634,6 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
   return getOutputsMutable();
 }
 
-FailureOr<LoopLikeOpInterface>
-ForallOp::yieldTiledValuesAndReplace(RewriterBase &rewriter,
-                                     ValueRange newInitOperands,
-                                     YieldTiledValuesFn yieldTiledValuesFn) {
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(getOperation());
-  auto inits = llvm::to_vector(getOutputs());
-  inits.append(newInitOperands.begin(), newInitOperands.end());
-  auto newLoop = rewriter.create<scf::ForallOp>(
-      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
-      inits, getMapping(), [](OpBuilder &, Location, ValueRange) {});
-
-  // Move the region of the current block to the newly created op.
-  Block *newLoopBody = newLoop.getBody();
-  rewriter.mergeBlocks(
-      getBody(), newLoopBody,
-      newLoopBody->getArguments().take_front(getBody()->getNumArguments()));
-
-  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
-  rewriter.setInsertionPoint(terminator);
-  SmallVector<Value> tiledValues;
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
-  ValueRange regionIterArgs =
-      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
-  if (failed(yieldTiledValuesFn(rewriter, getLoc(), newLoop.getInductionVars(),
-                                regionIterArgs, tiledValues, resultOffsets,
-                                resultSizes))) {
-    return rewriter.notifyMatchFailure(getOperation(),
-                                       "failed to get yielded tiled values");
-  }
-
-  // Update the terminator.
-  rewriter.setInsertionPointToEnd(terminator.getBody());
-
-  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
-           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
-    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
-                                           rewriter.getIndexAttr(1));
-    rewriter.create<tensor::ParallelInsertSliceOp>(
-        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
-        resultStride);
-  }
-
-  rewriter.replaceOp(getOperation(),
-                     newLoop->getResults().take_front(getNumResults()));
-  return cast<LoopLikeOpInterface>(newLoop.getOperation());
-}
-
 /// Promotes the loop body of a scf::ForallOp to its containing block.
 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 50a85e6e34e240..fa842d1058a332 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -287,6 +288,171 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
+/// A function that allows returning additional yielded values during
+/// `yieldTiledValuesAndReplace`.
+/// - `ivs` induction variable for the loop.
+/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
+/// - `tiledValues` the tiled values to return. Must be of same size as
+///   `newbbArgs`, each element of this array is inserted into the corresponding
+///   element in `newbbArgs`.
+/// - `resultOffsets` is of the same size as `tiledValues` and represents
+///   the offsets to use when inserting corresponding element from `tiledValues`
+///   into the element from `newBbArgs`.
+/// - `resultSizes` is of the same size as `tiledValues` and represents
+///   the size of the corresponding element from `tiledValues` inserted into
+///   the element from `newBbArgs`.
+using YieldTiledValuesFn = llvm::function_ref<LogicalResult(
+    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
+    SmallVector<Value> &tiledValues,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
+
+/// Append the specified additional `newInitOperands` operands to the
+/// loops existing `init` operands (or similar), and replace `loopOp` with
+/// the new loop that has the additional init operands. The loop body of
+/// this loop is moved over to the new loop. `yieldTiledValuesFn`
+/// is called to get the new tiled values returned, and the offset
+/// and sizes at which the tiled value is inserted into the
+/// new region iter_args that correspond to the newly added init operands.
+template <typename LoopType>
+FailureOr<LoopLikeOpInterface>
+yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
+                               ValueRange newInitOperands,
+                               YieldTiledValuesFn yieldTiledValuesFn) {
+  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
+}
+
+/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
+template <>
+FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
+    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
+    YieldTiledValuesFn yieldTiledValuesFn) {
+  OpBuilder::InsertionGuard g(rewriter);
+  Location loc = loopOp.getLoc();
+  rewriter.setInsertionPoint(loopOp);
+
+  auto inits = llvm::to_vector(loopOp.getInitArgs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  auto newLoop = rewriter.create<scf::ForOp>(
+      loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
+      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Move the loop body to the new op.
+  Block *loopBody = loopOp.getBody();
+  Block *newLoopBody = newLoop.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
+  rewriter.setInsertionPoint(yieldOp);
+
+  SmallVector<Value> tiledValues;
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  ValueRange newRegionIterArgs =
+      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
+  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
+                                newRegionIterArgs, tiledValues, resultOffsets,
+                                resultSizes))) {
+    return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
+  }
+
+  if (tiledValues.size() != resultOffsets.size() ||
+      tiledValues.size() != resultSizes.size()) {
+    return rewriter.notifyMatchFailure(
+        loopOp,
+        "expected number of tiled values returned, the number of offset "
+        "vectors and number of size vectors to be the same");
+  }
+
+  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
+  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
+       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
+                       resultSizes)) {
+    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
+                                           rewriter.getIndexAttr(1));
+    Value insert = rewriter.create<tensor::InsertSliceOp>(
+        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
+        resultStride);
+    newYieldValues.push_back(insert);
+  }
+
+  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
+  rewriter.replaceOp(loopOp,
+                     newLoop->getResults().take_front(loopOp.getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
+/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
+template <>
+FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
+    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
+    YieldTiledValuesFn yieldTiledValuesFn) {
+  OpBuilder::InsertionGuard g(rewriter);
+  Location loc = loopOp.getLoc();
+  rewriter.setInsertionPoint(loopOp);
+  auto inits = llvm::to_vector(loopOp.getOutputs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  auto newLoop = rewriter.create<scf::ForallOp>(
+      loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
+      loopOp.getMixedStep(), inits, loopOp.getMapping(),
+      [](OpBuilder &, Location, ValueRange) {});
+
+  // Move the region of the current block to the newly created op.
+  Block *loopBody = loopOp.getBody();
+  Block *newLoopBody = newLoop.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
+  rewriter.setInsertionPoint(terminator);
+  SmallVector<Value> tiledValues;
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  ValueRange regionIterArgs =
+      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
+  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
+                                regionIterArgs, tiledValues, resultOffsets,
+                                resultSizes))) {
+    return rewriter.notifyMatchFailure(loopOp,
+                                       "failed to get yielded tiled values");
+  }
+
+  // Update the terminator.
+  rewriter.setInsertionPointToEnd(terminator.getBody());
+
+  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
+           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
+                                           rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
+        resultStride);
+  }
+
+  rewriter.replaceOp(loopOp,
+                     newLoop->getResults().take_front(loopOp.getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
+/// Implementation of `yieldTiledValuesAndReplaceLoop` for
+/// `LoopLikeOpInterface`, that just dispatches to the implementation for each
+/// supported loop type.
+FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
+    LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
+    ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
+  return TypeSwitch<LoopLikeOpInterface, FailureOr<LoopLikeOpInterface>>(
+             loopLikeOp)
+      .Case<scf::ForOp, scf::ForallOp>(
+          [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
+            return yieldTiledValuesAndReplaceLoop(
+                loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
+          })
+      .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
+        return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
+      });
+}
+
 /// Method to add new init values to a loop nest. Updates `loops` in-place with
 /// new loops that use the `newInitValues`.
 /// The outer-loops are updated to yield the new result values of the inner
@@ -334,8 +500,8 @@ static LogicalResult addInitOperandsToLoopNest(
   // Update the loop body of the innermost loop to get new yield values.
   LoopLikeOpInterface innerMostLoop = loops.back();
   FailureOr<LoopLikeOpInterface> newInnerMostLoop =
-      innerMostLoop.yieldTiledValuesAndReplace(rewriter, newInitValues,
-                                               getNewTiledYieldsFn);
+      yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
+                                     getNewTiledYieldsFn);
 
   if (failed(newInnerMostLoop))
     return innerMostLoop.emitOpError("failed to return additional yields");

>From 09dab87597250a9c27d083847150bccdb69f58ea Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Fri, 19 Jan 2024 13:48:56 -0800
Subject: [PATCH 12/12] Try not using `llvm::function_ref`.

---
 .../mlir/Interfaces/LoopLikeInterface.h       | 24 +---------
 .../SCF/Transforms/TileUsingInterface.cpp     | 44 +++++++++----------
 2 files changed, 23 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index ca7c86cff81da9..7c7d378d0590ab 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -22,31 +22,9 @@ class RewriterBase;
 /// `replaceWithAdditionalYields`. `newBbArgs` are the newly added region
 /// iter_args. This function should return as many values as there are block
 /// arguments in `newBbArgs`.
-using NewYieldValuesFn = llvm::function_ref<SmallVector<Value>(
+using NewYieldValuesFn = std::function<SmallVector<Value>(
     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
 
-/// A function that allows returning additional yielded values during
-/// `yieldTiledValuesAndReplace`.
-/// - `ivs` induction variable for the loop.
-/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
-/// - `tiledValues` the tiled values to return. Must be of same size as
-///   `newbbArgs`, each element of this array is inserted into the corresponding
-///   element in `newbbArgs`.
-/// - `resultOffsets` is of the same size as `tiledValues` and represents
-///   the offsets to use when inserting corresponding element from `tiledValues`
-///   into the element from `newBbArgs`.
-/// - `resultSizes` is of the same size as `tiledValues` and represents
-///   the size of the corresponding element from `tiledValues` inserted into
-///   the element from `newBbArgs`.
-/// - `resultStrides` is of the same size as `tiledValues` and represents
-///   the strides to use when inserting corresponding element from `tiledValues`
-///   into the element from `newBbArgs`.
-using YieldTiledValuesFn = llvm::function_ref<LogicalResult(
-    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
-    SmallVector<Value> &tiledValues,
-    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
-    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
-
 namespace detail {
 /// Verify invariants of the LoopLikeOpInterface.
 LogicalResult verifyLoopLikeOpInterface(Operation *op);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index fa842d1058a332..ad0d0a1d4c734d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -99,6 +99,25 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
 }
 
+/// A function that allows returning additional yielded values during
+/// `yieldTiledValuesAndReplaceLoop`.
+/// - `ivs` induction variable for the loop.
+/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
+/// - `tiledValues` the tiled values to return. Must be of same size as
+///   `newBbArgs`, each element of this array is inserted into the corresponding
+///   element in `newBbArgs`.
+/// - `resultOffsets` is of the same size as `tiledValues` and represents
+///   the offsets to use when inserting corresponding element from `tiledValues`
+///   into the element from `newBbArgs`.
+/// - `resultSizes` is of the same size as `tiledValues` and represents
+///   the size of the corresponding element from `tiledValues` inserted into
+///   the element from `newBbArgs`.
+using YieldTiledValuesFn = std::function<LogicalResult(
+    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
+    SmallVector<Value> &tiledValues,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
+
 /// Clones the operation and updates the destination if the operation
 /// implements the `DestinationStyleOpInterface`.
 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
@@ -288,25 +307,6 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
-/// A function that allows returning additional yielded values during
-/// `yieldTiledValuesAndReplace`.
-/// - `ivs` induction variable for the loop.
-/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
-/// - `tiledValues` the tiled values to return. Must be of same size as
-///   `newbbArgs`, each element of this array is inserted into the corresponding
-///   element in `newbbArgs`.
-/// - `resultOffsets` is of the same size as `tiledValues` and represents
-///   the offsets to use when inserting corresponding element from `tiledValues`
-///   into the element from `newBbArgs`.
-/// - `resultSizes` is of the same size as `tiledValues` and represents
-///   the size of the corresponding element from `tiledValues` inserted into
-///   the element from `newBbArgs`.
-using YieldTiledValuesFn = llvm::function_ref<LogicalResult(
-    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
-    SmallVector<Value> &tiledValues,
-    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
-    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
-
 /// Append the specified additional `newInitOperands` operands to the
 /// loops existing `init` operands (or similar), and replace `loopOp` with
 /// the new loop that has the additional init operands. The loop body of
@@ -441,8 +441,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
     LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
     ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
-  return TypeSwitch<LoopLikeOpInterface, FailureOr<LoopLikeOpInterface>>(
-             loopLikeOp)
+  return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
+             loopLikeOp.getOperation())
       .Case<scf::ForOp, scf::ForallOp>(
           [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
             return yieldTiledValuesAndReplaceLoop(
@@ -460,7 +460,7 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
 /// the additional values to yield form the innermost loop.
 static LogicalResult addInitOperandsToLoopNest(
     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
-    ValueRange newInitValues, const YieldTiledValuesFn &getNewTiledYieldsFn) {
+    ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
   SmallVector<scf::ForOp> newLoops;
   if (loops.empty())
     return success();



More information about the flang-commits mailing list