[Mlir-commits] [mlir] [mlir][TilingInterface] Handle multi operand consumer fusion. (PR #145193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 25 11:19:25 PDT 2025


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/145193

>From 35bc469b086045f48c7e3b10895f2f9bd7df4ca9 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 18 Jun 2025 17:05:43 -0700
Subject: [PATCH 1/2] [mlir][TilingInterface] Handle multi operand consumer
 fusion.

For consumer fusion cases of this form

```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {

  tensor.parallel_insert_slice ... into %arg0
  tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```

the current consumer fusion, which handles one slice at a time, cannot
fuse the consumer into the loop, since fusing along one slice will
create an SSA violation on the other use from the `scf.forall`. The
solution is to allow consumer fusion to consider multiple slices at
once. This PR changes the `TilingInterface` methods related to
consumer fusion, i.e.

- `getTiledImplementationFromOperandTile`
- `getIterationDomainTileFromOperandTile`

to allow fusion while considering multiple operands. It is up to the
`TilingInterface` implementation to return an error if a list of tiles
of the operands cannot result in a consistent implementation of the
tiled operation.

The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes and to handle the cases where the
operand tiles result in a consistent tiling implementation.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../SCF/Transforms/TileUsingInterface.h       |  22 +-
 .../Dialect/Tensor/Transforms/Transforms.h    |  14 +-
 mlir/include/mlir/IR/OpDefinition.h           |   2 +-
 .../mlir/Interfaces/TilingInterface.td        |  55 ++--
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 182 +++++++----
 .../SCF/Transforms/TileUsingInterface.cpp     | 209 +++++++++----
 .../SwapExtractSliceWithProducerPatterns.cpp  |  47 ++-
 .../transform-op-fuse-into-containing.mlir    |   1 +
 .../tile-and-fuse-consumer.mlir               | 293 +++++++++++++++++-
 .../TestTilingInterfaceTransformOps.cpp       |  47 +--
 .../TestTilingInterfaceTransformOps.td        |   7 +-
 11 files changed, 672 insertions(+), 207 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 7b43aa43c7517..3205da6e448fc 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -313,19 +313,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                      TilingInterface consumer,
                                      const SCFTileAndFuseOptions &options);
 
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place.  Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer of `candidateSlices` by computing the required slice of
+/// the consumer in-place. All the entries of `candidateSlices` are expected to
+/// map to the same consumer. The method returns an error if the consumer cannot
+/// be tiled in a manner that is consistent for all the passed slices. Note that
+/// the method replaces the uses of `candidateSlices` with the tiled and fused
+/// consumer value but does not delete the slice operations.
 struct SCFFuseConsumerOfSliceResult {
-  OpOperand *origConsumerOperand; // Original untiled consumer's operand.
-  OpOperand
-      *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+  // Original untiled consumer operands.
+  SmallVector<OpOperand *> origConsumerOperands;
+  // Tiled and fused consumer operands.
+  SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
   SmallVector<Operation *> tiledOps;
 };
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
-                           MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+                            ArrayRef<Operation *> candidateSlices,
+                            MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..9860b06348407 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap `tensor.insert_slice`s with their consumer when the
+/// consumer implements the `TilingInterface`. The size of `sliceOps` and
+/// `consumerOperands` is expected to be the same. Every entry in
+/// `consumerOperands` represents a use of the corresponding
+/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` are
+/// expected to be uses in the same consumer.
 FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
-                                    OffsetSizeAndStrideOpInterface sliceOp,
-                                    OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+                                     ArrayRef<tensor::InsertSliceOp> sliceOps,
+                                     ArrayRef<OpOperand *> consumerOperands);
 
 //===----------------------------------------------------------------------===//
 // Populate functions.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 31f54413a5ff0..663c256c848df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
   using PointerUnion<Attribute, Value>::PointerUnion;
 
 public:
-  void dump() const { llvm::errs() << *this << "\n"; }
+  LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
 
   MLIRContext *getContext() const {
     PointerUnion pu = *this;
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 0de37338c95e4..e64d870048e0e 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the tiled implementation of an operation that uses
-          exactly a tile of the given operand.
+          exactly tiles of the given operands.
 
           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of the producer, this
-          method generates the tile of the consumer that uses exactly this
-          produced tile. In some sense it is the "reverse" of
+          with an (already tiled) producer. Given tiles of the producer, this
+          method generates the tile of the consumer that uses exactly these
+          produced tiles. In some sense it is the "reverse" of
           `generateResultTileValue`.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
-            the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `operandNumbers` is the list of operands whose tiles are "producers".
+          - `allOffsets` is the offset of the slice of the producer used by the
+            tiled implementation of the consumer.
+          - `allSizes` is the size of the slice of the producer used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
+          If it is illegal to fuse with a producer along the given operand tiles for
           an operation, the implementation should return a failure.
         }],
         /*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
-        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*methodName=*/"getTiledImplementationFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           tile of the operand.
 
           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of an operand,
-          returns the tile of the iteration space that uses this tile.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
+          with an (already tiled) producer. Given tiles of operands,
+          returns the tile of the iteration space that uses these tiles.
+          - `operandNumbers` is the list of operands whose tiles are "produced"
+            by the producer(s).
+          - `allOffsets` is the offset of the slice of the producers used by
             the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `allSizes` is the size of the slice of the producers used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
-          an operation, or if this mapping cannot be computed, the
-          implementation should return a failure.
+          If it is illegal to fuse with the producer slices for an operation,
+          or if this mapping cannot be computed, the implementation should
+          return a failure.
 
           Note that unlike the "tile consumer and fuse producer" case, the
           "tile producer and fuse consumer" requires an additional method to get
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformation. It does not provide guarantees on whether such a
           transformation is profitable.
 
-          For most cases `getTiledImplementationFromOperandTile` could be a
-          implemented using `getIterationDomainTileFromOperandTile` +
+          For most cases `getTiledImplementationFromOperandTiles` could be a
+          implemented using `getIterationDomainTileFromOperandTiles` +
           `getTiledImplementation` methods.
         }],
         /*retType=*/"::llvm::LogicalResult",
-        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*methodName=*/"getIterationDomainTileFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 19d484a3bb701..db995397ced67 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -148,55 +148,80 @@ struct LinalgOpTilingInterface
   /// Utility to fetch the offsets and sizes when applied as per the indexing
   /// map of the linalg op. This helps in fusing the linalg op as a consumer of
   /// a given slice op.
-  void
-  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
-                         ArrayRef<OpFoldResult> offsets,
-                         ArrayRef<OpFoldResult> sizes,
-                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
-                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
-    unsigned numLoops = linalgOp.getNumLoops();
-    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
-    mappedOffsets.resize(numLoops);
-    mappedSizes.resize(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
-        mappedOffsets[index] = value.offset;
-        mappedSizes[index] = value.size;
+  static LogicalResult
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+                         ArrayRef<AffineMap> indexingMaps,
+                         ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+                         ArrayRef<SmallVector<OpFoldResult>> allSizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
+                         SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
+    DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+
+    for (auto [indexingMap, offsets, sizes] :
+         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+      for (auto [resultExpr, offset, size] :
+           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+        if (!dimExpr)
+          continue;
+        unsigned position = dimExpr.getPosition();
+        auto it = mappedOffsets.find(position);
+        if (it != mappedOffsets.end()) {
+          OpFoldResult seenOffset = it->second;
+          OpFoldResult seenSize = mappedSizes.lookup(position);
+          if (seenOffset != offset || seenSize != size) {
+            return linalgOp->emitOpError(
+                "inconsistent iteration space mapping from offsets/sizes of "
+                "operands/results");
+          }
+        } else {
+          mappedOffsets[position] = offset;
+          mappedSizes[position] = size;
+        }
       }
     }
-    for (const auto &&[index, value] :
-         llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
-      mappedOffsets[dimPosition] = offsets[index];
-      mappedSizes[dimPosition] = sizes[index];
+
+    // Aggregate from the given operand offsets and sizes, or default to
+    // iteration space values.
+    SmallVector<Range> iterationDomain =
+        cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
+    mappedOffsetsVec.resize(iterationDomain.size());
+    mappedSizesVec.resize(iterationDomain.size());
+    for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
+      auto it = mappedOffsets.find(index);
+      if (it != mappedOffsets.end()) {
+        mappedOffsetsVec[index] = it->second;
+        mappedSizesVec[index] = mappedSizes.lookup(index);
+        continue;
+      }
+      mappedOffsetsVec[index] = domain.offset;
+      mappedSizesVec[index] = domain.size;
     }
+    return success();
   }
 
   /// Method to return the position of the result tile computed by the tiled
   /// operation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
       SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
-    // Check that the indexing map used for the operand is a projected
-    // permutation. This could be relaxed with a more general approach that can
-    // map the offsets and sizes from the operand to iteration space tiles
-    // (filling in full extent for dimensions not used to access the result).
-    AffineMap indexingMap =
-        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
-    if (!indexingMap.isProjectedPermutation()) {
-      return op->emitError()
-             << "unhandled get iter domain position when operand is not "
-                "accessed using a permuted projection";
+    std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
+        iterationSpaceSizes;
+    SmallVector<AffineMap> indexingMaps =
+        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+          return linalgOp.getMatchingIndexingMap(&opOperand);
+        });
+    if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
+                                      allSizes, iterDomainOffsets,
+                                      iterDomainSizes))) {
+      return failure();
     }
-
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
     return success();
   }
 
@@ -247,8 +272,13 @@ struct LinalgOpTilingInterface
           "accessed using a permuted projection");
     }
 
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
+    SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
+    SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
+    auto status =
+        getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
+                               {allSizes}, iterDomainOffsets, iterDomainSizes);
+    (void)status;
+    assert(succeeded(status) && "unexpected error in offset calculation");
     return success();
   }
 
@@ -279,12 +309,13 @@ struct LinalgOpTilingInterface
 
   /// Method to generate the tiled implementation of an operation from the tile
   /// of the operand.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
     SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, operandNumber, offsets, sizes, mappedOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
             mappedSizes))) {
       return failure();
     }
@@ -837,13 +868,17 @@ struct PackOpTiling
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
   /// `resultSizes` only cover outer dimensions.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
-    if (operandNumber != 0)
-      return failure();
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+      return op->emitOpError("unsupporeted operands for consumer fusion");
+
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
     // It is not trivial to infer dest tile from source tile if `packOp` has
@@ -904,11 +939,15 @@ struct PackOpTiling
   }
 
   /// Method to return the tiled implementation of tensor.pack as a consumer.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
-    if (operandNumber != 0)
-      return failure();
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+      return op->emitOpError("unhandled operands for consumer fusion");
+
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
     Location loc = packOp.getLoc();
@@ -923,8 +962,8 @@ struct PackOpTiling
     tiledOperands.push_back(sourceSlice);
 
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
             outerDimSizes)))
       return failure();
 
@@ -1182,12 +1221,20 @@ struct UnPackOpTiling
 
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
+    if (operandNumbers.size() != 1) {
+      return op->emitOpError("unable to handle multiple operands");
+    }
     auto unPackOp = cast<UnPackOp>(op);
+    unsigned operandNumber = operandNumbers[0];
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
+
     // If the operand tile is the dest, then no adjustment is needed.
     if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
       resultOffsets = llvm::to_vector(offsets);
@@ -1241,10 +1288,17 @@ struct UnPackOpTiling
   }
 
   /// Method to return the tiled implementation of tensor.unpack as a consumer.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
+      return op->emitOpError("unhandled operands for consumer fusion");
+    }
     auto unPackOp = cast<UnPackOp>(op);
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
+
     // tensor.unpack op is fusible (as a consumer) only if inner dims are not
     // tiled.
     int64_t numTiles = unPackOp.getInnerDimsPos().size();
@@ -1259,8 +1313,8 @@ struct UnPackOpTiling
     // Fetch offset/size for creating the slice of the dest operand of
     // unpack op.
     SmallVector<OpFoldResult> outputOffsets, outputSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
             outputSizes)))
       return failure();
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index ddcae8481a5b4..75f2f8fa50894 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -2047,53 +2047,118 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
 
 /// A utility to fetch an untiled consumer of
 /// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
-                            MutableArrayRef<LoopLikeOpInterface> loops) {
+static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
   assert(!loops.empty() && "unexpected empty loops");
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
-  } else {
-    return failure();
+  assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
+  SmallVector<OpOperand *> fusedOperands;
+  for (auto sliceOp : sliceOps) {
+    FailureOr<OpOperand *> fusedOperand =
+        TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
+            .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+                [&](auto op) {
+                  return getUntiledConsumerFromSlice(rewriter, op, loops);
+                })
+            .Default([](Operation *op) {
+              return op->emitOpError("unhandled slice type");
+            });
+    if (failed(fusedOperand)) {
+      return failure();
+    }
+    if (!fusedOperands.empty() &&
+        fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
+      return fusedOperands.front()->getOwner()->emitOpError(
+          "all candidate slices must be to the same consumer");
+    }
+    fusedOperands.push_back(fusedOperand.value());
+  }
+  return fusedOperands;
+}
+
+template <typename InsertSliceOpTy>
+static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
+                                                InsertSliceOpTy sliceOp);
+
+template <>
+tensor::InsertSliceOp
+cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
+                                          tensor::InsertSliceOp insertSliceOp) {
+  return cast<tensor::InsertSliceOp>(
+      rewriter.clone(*insertSliceOp.getOperation()));
+}
+
+template <>
+tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
+  return rewriter.create<tensor::InsertSliceOp>(
+      insertSliceOp->getLoc(), insertSliceOp.getSource(),
+      insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
+      insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+}
+
+static SmallVector<tensor::InsertSliceOp>
+cloneAsInsertSlices(RewriterBase &rewriter,
+                    ArrayRef<Operation *> candidateSlices) {
+  assert(!candidateSlices.empty() &&
+         "unexpected empty list of slices to clone");
+  SmallVector<tensor::InsertSliceOp> clonedSlices;
+  for (auto sliceOp : candidateSlices) {
+    TypeSwitch<Operation *>(sliceOp)
+        .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+            [&](auto op) {
+              auto clonedOp = cloneAsInsertSlice(rewriter, op);
+              clonedSlices.push_back(clonedOp);
+            })
+        .Default([&](Operation *op) {
+          // Assert here assuming this has already been checked.
+          assert(0 && "unexpected slice type while cloning as insert slice");
+        });
   }
+  return clonedSlices;
 }
 
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(
-    RewriterBase &rewriter, Operation *candidateSliceOp,
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
     MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return emitError(rewriter.getUnknownLoc(),
+                     "no candidate slices provided for consumer fusion");
+  }
   // Return if `loops` is empty, return an error for now. Caller is expected
   // to handle this case.
   if (loops.empty()) {
-    return candidateSliceOp->emitOpError(
+    return candidateSlices.front()->emitOpError(
         "cannot call tile and fuse consumer with an empty loop nest");
   }
-  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
-          candidateSliceOp))
-    return failure();
+
+  if (!(llvm::all_of(
+            candidateSlices,
+            [](Operation *op) { return isa<tensor::InsertSliceOp>(op); }) ||
+        llvm::all_of(candidateSlices, [](Operation *op) {
+          return isa<tensor::ParallelInsertSliceOp>(op);
+        }))) {
+    return candidateSlices.front()->emitOpError(
+        "candidates slices need to be all `tensor.extract_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
 
   // 1. Get the consumer of scf.for for the result yielded by
   // tensor.insert_slice/parallel_insert_slice.
-  FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
-  if (failed(maybeConsumerOpOperand)) {
-    return rewriter.notifyMatchFailure(candidateSliceOp,
-                                       "could not fetch consumer to fuse");
-  }
-  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
-  Operation *consumerOp = consumerOpOperand->getOwner();
-  unsigned operandNumber = consumerOpOperand->getOperandNumber();
-  unsigned resultNumber = 0;
-  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
-    resultNumber = producerResult.getResultNumber();
-  } else {
-    return rewriter.notifyMatchFailure(
-        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+  SmallVector<OpOperand *> consumerOpOperands;
+  Operation *consumerOp;
+  {
+    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+    if (failed(maybeConsumerOpOperand)) {
+      return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                         "could not fetch consumer to fuse");
+    }
+    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+    consumerOp = consumerOpOperands.front()->getOwner();
   }
 
   LoopLikeOpInterface outerMostLoop = loops.front();
@@ -2113,16 +2178,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   if (!dstOp)
     return rewriter.notifyMatchFailure(consumerOp,
                                        "consumer op is not DPS operation");
-  SmallVector<Value> dpsInits =
-      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
+  if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
+        return dstOp.isDpsInit(opOperand);
+      })) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  SmallVector<Value> newInits = dpsInits;
-
-  Location loc = outerMostLoop->getLoc();
+  SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
 
   // 3. Move the whole loop structure right before firstUserOfLoop, the
   // dominance should be already ensured by `checkAssumptionForLoop`.
@@ -2137,43 +2200,52 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
-  tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
-          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
-    rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+    rewriter.setInsertionPoint(candidateSlices.front());
   }
+  // 5.a. Clone all the candidate slices as equivalent insert slice ops.
+  SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
+      cloneAsInsertSlices(rewriter, candidateSlices);
 
-  // 5.a. Clone consumer op.
+  // 5.b. Clone consumer op.
   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
+  SmallVector<unsigned> operandNumbers =
+      llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
+        return opOperand->getOperandNumber();
+      });
+  SmallVector<OpOperand *> clonedOpFusedOperandsList =
+      llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
+        return &clonedConsumerOp->getOpOperand(operandNum);
+      });
 
-  // 5.b. Replace all uses of the loop result with the result of the cloned
+  // 5.c. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
-  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-    operandToReplace.set(clonedInsertSliceOp.getResult());
+    for (auto [operandToReplace, clonedSliceOp] :
+         llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
+      operandToReplace->set(clonedSliceOp.getResult());
+    }
   });
 
   // 6. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
-  auto ossSliceOp =
-      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
-      tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+      tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
+                                                   clonedOpFusedOperandsList);
   if (failed(tileAndFuseResult)) {
     return failure();
   }
+
   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
-  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
-                              clonedInsertSliceOp.getSource());
+  for (auto [operandNum, clonedSliceOp] :
+       llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
+    rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
+                                clonedSliceOp.getSource());
+  }
 
   // 7. Reconstruct [nested] loop with new inits.
   YieldTiledValuesFn newYieldValuesFn =
@@ -2185,14 +2257,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     // 8. Set inner insertPoint right before tiled consumer op.
     innerRewriter.setInsertionPoint(tiledConsumerOp);
 
-    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+    SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
+    for (auto candidateSliceOp : clonedInsertSlices) {
+      SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+      SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
 
-    // 9. Check all insert stride is 1.
-    if (!llvm::all_of(strides, isOneInteger)) {
-      return rewriter.notifyMatchFailure(
-          candidateSliceOp, "containingOp's result yield with stride");
+      // 9. Record offsets and sizes (unit strides verified in step 6).
+      allOffsets.emplace_back(std::move(offsets));
+      allSizes.emplace_back(std::move(sizes));
     }
 
     // 10. Try to get iter domain position from input position. Use
@@ -2202,8 +2274,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     // tiledConsumerOp could lead to some chained unnecessary extra index
     // computation.
     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
+            rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
             iterDomainSizes))) {
       return rewriter.notifyMatchFailure(
           clonedConsumerOp,
@@ -2279,10 +2351,13 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   // 16. Need to erase the old scf loop and the cloned consumer op.
   rewriter.eraseOp(clonedConsumerOp);
 
+  SmallVector<OpOperand *> tiledAndFusedOpOperands =
+      llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
+        return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
+      });
   return scf::SCFFuseConsumerOfSliceResult{
-      consumerOpOperand,
-      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
-      tileAndFuseResult->tiledOps};
+      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(tileAndFuseResult->tiledOps)};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 6f33f9b55ceb6..10831e790d739 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -39,21 +39,48 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
   return *tiledResult;
 }
 
-FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
-    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
-    OpOperand &consumer) {
-  auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
+FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
+    OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
+    ArrayRef<OpOperand *> consumerOperands) {
+  if (sliceOps.empty()) {
+    return emitError(builder.getUnknownLoc(),
+                     "expected candidate slices list to be non-empty");
+  }
+  if (sliceOps.size() != consumerOperands.size()) {
+    return sliceOps.front()->emitOpError(
+        "expected as many operands as the number of slices passed");
+  }
+  auto consumerOp =
+      dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
   if (!consumerOp)
     return failure();
+  for (auto opOperand : consumerOperands.drop_front()) {
+    if (opOperand->getOwner() != consumerOp) {
+      return consumerOp->emitOpError(
+          "expected all consumer operands to be from the same operation");
+    }
+  }
 
-  // `TilingInterface` currently only supports strides being 1.
-  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
-    return failure();
+  auto consumerOperandNums = llvm::map_to_vector(
+      consumerOperands, [](OpOperand *opOperand) -> unsigned {
+        return opOperand->getOperandNumber();
+      });
+  SmallVector<SmallVector<OpFoldResult>> allOffsets;
+  SmallVector<SmallVector<OpFoldResult>> allSizes;
+  for (auto sliceOp : sliceOps) {
+
+    // `TilingInterface` currently only supports strides being 1.
+    if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
+      return failure();
 
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+    allOffsets.emplace_back(std::move(offsets));
+    allSizes.emplace_back(std::move(sizes));
+  }
   FailureOr<TilingResult> tiledResult =
-      consumerOp.getTiledImplementationFromOperandTile(
-          builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
-          sliceOp.getMixedSizes());
+      consumerOp.getTiledImplementationFromOperandTiles(
+          builder, consumerOperandNums, allOffsets, allSizes);
   if (failed(tiledResult))
     return failure();
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 572a2ae70e0a4..5bdb5073ee865 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -653,6 +653,7 @@ module {
       %5 = affine.min #map2(%i)[%d0, %idx]
       %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
 
+      // CHECK: linalg.generic
       // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
       // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
       %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 77e52946b830f..f58d294d9611a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
 
 #map = affine_map<(d0) -> (d0)>
 module {
@@ -620,3 +620,294 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) =  (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic:2 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	%1 = arith.addf %b0, %b2 : f32
+	linalg.yield %0, %1 : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }	
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion1(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//   CHECK-DAG:     %[[GENERIC:.+]]:2 = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that consistent operand tiles from different producers can be fused.
+
+func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) =  (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+	%0 = arith.addf %b0, %b1 : f32
+	linalg.yield %0: f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }	
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion2(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]], %[[GENERIC1]] :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) =  (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	      iterator_types = ["parallel", "reduction"]}
+	      ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+      	%0 = arith.addf %b0, %b1 : f32
+	      linalg.yield %0: f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1
+//   CHECK-DAG:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//   CHECK-DAG:     %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
+//   CHECK-DAG:     %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]], %[[GENERIC1]] :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+//       CHECK:   return %[[RESULT]]#2
+
+// -----
+
+func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) =  (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+      	%0 = arith.addf %b0, %b1 : f32
+	      linalg.yield %0: f32
+    } -> tensor<?x?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  // expected-error @below {{inconsistent iteration space mapping from offsets/sizes of operands/results}}
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 9971f0cde4ed2..3c810f6f4accf 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -168,29 +168,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 
 /// Apply fusing of consumer transformation to all payload ops and store both
 /// the original consumer operation as well as the fused consumer operation.
-template <typename Range>
 static LogicalResult applyFuseConsumer(
-    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
-    MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
-    TransformResults &transformResults) {
+    RewriterBase &rewriter, Operation *transformOp,
+    ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
+    uint32_t numConsumerToFuse, TransformResults &transformResults) {
   SmallVector<Operation *> originalConsumerOps;
   SmallVector<Operation *> fusedConsumerOps;
 
-  for (Operation *target : payloadOps) {
-    rewriter.setInsertionPoint(target);
+  rewriter.setInsertionPoint(slices.front());
 
-    while (numConsumerToFuse--) {
-      FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
-          scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
+  while (numConsumerToFuse--) {
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+        scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
 
-      if (failed(fuseConsumerResults))
-        return failure();
+    if (failed(fuseConsumerResults))
+      return failure();
 
-      // Report back the relevant handles to the transform op.
-      originalConsumerOps.push_back(
-          fuseConsumerResults->origConsumerOperand->getOwner());
-      fusedConsumerOps.push_back(
-          fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+    // Report back the relevant handles to the transform op.
+    for (OpOperand *origConsumerOperand :
+         fuseConsumerResults->origConsumerOperands) {
+      originalConsumerOps.push_back(origConsumerOperand->getOwner());
+    }
+    for (OpOperand *tiledAndFusedConsumerOperand :
+         fuseConsumerResults->tiledAndFusedConsumerOperands) {
+      fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
     }
   }
 
@@ -203,6 +204,12 @@ DiagnosedSilenceableFailure
 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                      TransformResults &transformResults,
                                      TransformState &state) {
+  SmallVector<Operation *> slices;
+  for (auto op : getTargets()) {
+    auto sliceOp = *state.getPayloadOps(op).begin();
+    slices.push_back(sliceOp);
+  }
+
   SmallVector<LoopLikeOpInterface> loops;
   for (auto op : llvm::reverse(getLoops())) {
     auto loopLikeOp =
@@ -212,16 +219,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
     }
     loops.push_back(loopLikeOp);
   }
-  LogicalResult result = applyFuseConsumer(
-      rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
-      getNumConsumerToFuse(), transformResults);
+  LogicalResult result =
+      applyFuseConsumer(rewriter, getOperation(), slices, loops,
+                        getNumConsumerToFuse(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
 }
 
 void transform::TestFuseConsumerOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  consumesHandle(getTargetMutable(), effects);
+  consumesHandle(getTargetsMutable(), effects);
   consumesHandle(getLoopsMutable(), effects);
   producesHandle(getOperation()->getOpResults(), effects);
   modifiesPayload(effects);
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 98f7145c99cb1..3c09082e192ea 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
 }
 
 def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
+       [AttrSizedOperandSegments,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
         ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
@@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
   }];
 
   let arguments = (ins 
-      TransformHandleTypeInterface:$target,
+      Variadic<TransformHandleTypeInterface>:$targets,
       Variadic<TransformHandleTypeInterface>:$loops,
       DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
   let results = (outs TransformHandleTypeInterface:$consumer,
                       TransformHandleTypeInterface:$fused_consumer);
 
   let assemblyFormat = [{
-    $target `in` `(` $loops `)`
+    $targets `in` `(` $loops `)`
     (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? 
     attr-dict `:` functional-type(operands, results)
   }];

>From 75d320c48fbe8572c67cecb10bf4de526619992c Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Tue, 24 Jun 2025 22:29:04 -0700
Subject: [PATCH 2/2] Address comments (round 1).

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../Dialect/Tensor/Transforms/Transforms.h    |  2 +-
 .../mlir/Interfaces/TilingInterface.td        |  2 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 31 ++++++++++++-----
 .../SCF/Transforms/TileUsingInterface.cpp     | 33 +++++++++++--------
 .../SwapExtractSliceWithProducerPatterns.cpp  | 22 +++++++++----
 .../tile-and-fuse-consumer.mlir               |  4 +--
 .../TestTilingInterfaceTransformOps.cpp       |  5 ++-
 7 files changed, 66 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 9860b06348407..87deef9ca7466 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,7 +31,7 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
-/// Method to swap an `tensor.insert_slice`s with its consumer when the
+/// Method to swap `tensor.insert_slice`s with their consumers when the
 /// consumer implements the `TilingInterface`. The size of `sliceOps` and
 /// `consumerOperands` is expected to be the same. Every entry in
 /// `consumerOperands` represents a use of the corresponding
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index e64d870048e0e..0c0fc88aec95a 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,7 +202,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the tiled implementation of an operation that uses
-          exactly tiles of the given operands.
+          the exact tiles of the given operands.
 
           This method is required to allow operations to be "tiled and fused"
           with an (already tiled) producer. Given tiles of the producer, this
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index db995397ced67..513cecef29b61 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -22,8 +22,11 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/Support/Debug.h"
 #include <optional>
 
+#define DEBUG_TYPE "linalg-tiling-interface-impl"
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -170,9 +173,11 @@ struct LinalgOpTilingInterface
           OpFoldResult seenOffset = it->second;
           OpFoldResult seenSize = mappedSizes.lookup(position);
           if (seenOffset != offset || seenSize != size) {
-            return linalgOp->emitOpError(
-                "inconsistent iteration space mapping from offsets/sizes of "
-                "operands/results");
+            LLVM_DEBUG({
+              llvm::dbgs() << "inconsistent iteration space mapping from "
+                              "offsets/sizes of operands/results";
+            });
+            return failure();
           }
         } else {
           mappedOffsets[position] = offset;
@@ -874,8 +879,11 @@ struct PackOpTiling
       ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
-    if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
-      return op->emitOpError("unsupporeted operands for consumer fusion");
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
+      LLVM_DEBUG(
+          { llvm::dbgs() << "unsupported operands for consumer fusion"; });
+      return failure();
+    }
 
     ArrayRef<OpFoldResult> offsets(allOffsets[0]);
     ArrayRef<OpFoldResult> sizes(allSizes[0]);
@@ -943,8 +951,11 @@ struct PackOpTiling
       Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
       ArrayRef<SmallVector<OpFoldResult>> allOffsets,
       ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
-    if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
-      return op->emitOpError("unhandled operands for consumer fusion");
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
+      LLVM_DEBUG(
+          { llvm ::dbgs() << "unhandled operands for consumer fusion"; });
+      return failure();
+    }
 
     ArrayRef<OpFoldResult> offsets(allOffsets[0]);
     ArrayRef<OpFoldResult> sizes(allSizes[0]);
@@ -1228,7 +1239,8 @@ struct UnPackOpTiling
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
     if (operandNumbers.size() != 1) {
-      return op->emitOpError("unable to handle multiple operands");
+      LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
+      return failure();
     }
     auto unPackOp = cast<UnPackOp>(op);
     unsigned operandNumber = operandNumbers[0];
@@ -1293,7 +1305,8 @@ struct UnPackOpTiling
       ArrayRef<SmallVector<OpFoldResult>> allOffsets,
       ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
     if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
-      return op->emitOpError("unhandled operands for consumer fusion");
+      LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
+      return failure();
     }
     auto unPackOp = cast<UnPackOp>(op);
     ArrayRef<OpFoldResult> offsets(allOffsets[0]);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 75f2f8fa50894..995120ad8680e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -2060,15 +2060,16 @@ static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
                 [&](auto op) {
                   return getUntiledConsumerFromSlice(rewriter, op, loops);
                 })
-            .Default([](Operation *op) {
-              return op->emitOpError("unhandled slice type");
+            .Default([&](Operation *op) {
+              return rewriter.notifyMatchFailure(op, "unhandled slice type");
             });
     if (failed(fusedOperand)) {
       return failure();
     }
     if (!fusedOperands.empty() &&
         fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
-      return fusedOperands.front()->getOwner()->emitOpError(
+      return rewriter.notifyMatchFailure(
+          fusedOperand.value()->getOwner(),
           "all candidate slices must be to the same consumer");
     }
     fusedOperands.push_back(fusedOperand.value());
@@ -2125,23 +2126,23 @@ mlir::scf::tileAndFuseConsumerOfSlices(
     RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
     MutableArrayRef<LoopLikeOpInterface> loops) {
   if (candidateSlices.empty()) {
-    return emitError(rewriter.getUnknownLoc(),
-                     "no candidate slices provided for consumer fusion");
+    return rewriter.notifyMatchFailure(
+        rewriter.getUnknownLoc(),
+        "no candidate slices provided for consumer fusion");
   }
   // If `loops` is empty, return an error for now. The caller is expected
   // to handle this case.
   if (loops.empty()) {
-    return candidateSlices.front()->emitOpError(
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
         "cannot call tile and fuse consumer with an empty loop nest");
   }
 
-  if (!(llvm::all_of(
-            candidateSlices,
-            [](Operation *op) { return isa<tensor::InsertSliceOp>(op); }) ||
-        llvm::all_of(candidateSlices, [](Operation *op) {
-          return isa<tensor::ParallelInsertSliceOp>(op);
-        }))) {
-    return candidateSlices.front()->emitOpError(
+  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+        llvm::all_of(candidateSlices,
+                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
         "candidates slices need to be all `tensor.extract_slice`s or "
         "`tensor.parallel_insert_slice`s");
   }
@@ -2261,8 +2262,14 @@ mlir::scf::tileAndFuseConsumerOfSlices(
     for (auto candidateSliceOp : clonedInsertSlices) {
       SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
       SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+      SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
 
       // 9. Check that every insert stride is 1.
+      if (!llvm::all_of(strides, isOneInteger)) {
+        return rewriter.notifyMatchFailure(
+            candidateSliceOp, "containingOp's result yield with stride");
+      }
+
       allOffsets.emplace_back(std::move(offsets));
       allSizes.emplace_back(std::move(sizes));
     }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 10831e790d739..4392a2c0eb839 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -17,6 +17,9 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tensor-swap-slices"
 
 using namespace mlir;
 
@@ -43,12 +46,16 @@ FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
     OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
     ArrayRef<OpOperand *> consumerOperands) {
   if (sliceOps.empty()) {
-    return emitError(builder.getUnknownLoc(),
-                     "expected candidate slices list to be non-empty");
+    LLVM_DEBUG(
+        { llvm::dbgs() << "expected candidate slices list to be non-empty"; });
+    return failure();
   }
   if (sliceOps.size() != consumerOperands.size()) {
-    return sliceOps.front()->emitOpError(
-        "expected as many operands as the number of slices passed");
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "expected as many operands as the number of slices passed";
+    });
+    return failure();
   }
   auto consumerOp =
       dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
@@ -56,8 +63,11 @@ FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
     return failure();
   for (auto opOperand : consumerOperands.drop_front()) {
     if (opOperand->getOwner() != consumerOp) {
-      return consumerOp->emitOpError(
-          "expected all consumer operands to be from the same operation");
+      LLVM_DEBUG({
+        llvm::dbgs()
+            << "expected all consumer operands to be from the same operation";
+      });
+      return failure();
     }
   }
 
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index f58d294d9611a..0f69875d596f1 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -688,7 +688,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// Check that when the given operand tiles are incosistent, tiling fails.
+// Check that when the given operand tiles are inconsistent, tiling fails.
 
 func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
@@ -881,6 +881,7 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<
 	      linalg.yield %0: f32
     } -> tensor<?x?xf32>
     scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
       tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
           : tensor<?x?xf32> into tensor<?x?xf32>
       tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
@@ -888,7 +889,6 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<
     }
   }
   %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
-  // expected-error @below {{inconsistent iteration space mapping from offsets/sizes of operands/results}}
   %result = linalg.generic {
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 3c810f6f4accf..ee3eb9522db7e 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -21,6 +21,9 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "test-tiling-interface"
 
 #define GET_OP_CLASSES
 #include "TestTilingInterfaceTransformOps.h.inc"
@@ -182,7 +185,7 @@ static LogicalResult applyFuseConsumer(
         scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
 
     if (failed(fuseConsumerResults))
-      return failure();
+      return slices.front()->emitOpError("failed to fuse consumer of slice");
 
     // Report back the relevant handles to the transform op.
     for (OpOperand *origConsumerOperand :



More information about the Mlir-commits mailing list