[Mlir-commits] [mlir] [mlir][TilingInterface] Handle multi operand consumer fusion. (PR #145193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jun 22 21:23:04 PDT 2025


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/145193

>From c86e95b1cc1f076a69cc86850174106c3f8c664a Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 18 Jun 2025 17:05:43 -0700
Subject: [PATCH] [mlir][TilingInterface] Handle multi operand consumer fusion.

For consumer fusion cases of this form

```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {

  tensor.parallel_insert_slice ... into %arg0
  tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```

the current consumer fusion that handles one slice at a time cannot
fuse the consumer into the loop, since fusing along one slice will
create an SSA violation on the other use of the `scf.forall` results. The
solution is to let consumer fusion consider multiple
slices at once. This PR changes the `TilingInterface` methods related
to consumer fusion, i.e.

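- `getTiledImplementationFromOperandTile` (renamed to `getTiledImplementationFromOperandTiles`)
- `getIterationDomainTileFromOperandTile` (renamed to `getIterationDomainTileFromOperandTiles`)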

to allow fusion while considering multiple operands. It is up to the
`TilingInterface` implementation to return an error if a list of tiles
of the operands cannot result in a consistent implementation of the
tiled operation.

The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes and allow cases where operand
tiles that can result in a consistent tiling implementation are handled.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../SCF/Transforms/TileUsingInterface.h       |  22 +-
 .../Dialect/Tensor/Transforms/Transforms.h    |  14 +-
 .../mlir/Interfaces/TilingInterface.td        |  55 ++--
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 182 +++++++----
 .../SCF/Transforms/TileUsingInterface.cpp     | 209 +++++++++----
 .../SwapExtractSliceWithProducerPatterns.cpp  |  47 ++-
 .../tile-and-fuse-consumer.mlir               | 289 ++++++++++++++++++
 .../TestTilingInterfaceTransformOps.cpp       |  47 +--
 .../TestTilingInterfaceTransformOps.td        |   7 +-
 9 files changed, 667 insertions(+), 205 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..c3b79fc860208 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -319,19 +319,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                      TilingInterface consumer,
                                      const SCFTileAndFuseOptions &options);
 
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place.  Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer of the loop results yielded by `candidateSlices` by
+/// computing the required slice of the consumer in-place. All entries of
+/// `candidateSlices` are expected to map to the same consumer; the method
+/// returns an error if the consumer cannot be tiled consistently for all of
+/// the passed slices. Note that the method replaces the uses of the slices
+/// with the tiled and fused consumer value but does not delete the slices.
 struct SCFFuseConsumerOfSliceResult {
-  OpOperand *origConsumerOperand; // Original untiled consumer's operand.
-  OpOperand
-      *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+  // Operands of the original, untiled consumer.
+  SmallVector<OpOperand *> origConsumerOperands;
+  // Operands of the tiled and fused consumer.
+  SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
   SmallVector<Operation *> tiledOps;
 };
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
-                           MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+                            ArrayRef<Operation *> candidateSlices,
+                            MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..9860b06348407 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap `tensor.insert_slice`s with their consumer when the
+/// consumer implements the `TilingInterface`. `sliceOps` and
+/// `consumerOperands` are expected to have the same size. Every entry in
+/// `consumerOperands` is the use, in the consumer, of the corresponding
+/// entry in `sliceOps`. All entries of `consumerOperands` are expected to
+/// be uses in the same consumer.
 FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
-                                    OffsetSizeAndStrideOpInterface sliceOp,
-                                    OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+                                     ArrayRef<tensor::InsertSliceOp> sliceOps,
+                                     ArrayRef<OpOperand *> consumerOperands);
 
 //===----------------------------------------------------------------------===//
 // Populate functions.
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..30d42993f23ec 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the tiled implementation of an operation that uses
-          exactly a tile of the given operand.
+          exactly tiles of the given operands.
 
           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of the producer, this
-          method generates the tile of the consumer that uses exactly this
-          produced tile. In some sense it is the "reverse" of
+          with an (already tiled) producer. Given tiles of the producer, this
+          method generates the tile of the consumer that uses exactly these
+          produced tiles. In some sense it is the "reverse" of
           `generateResultTileValue`.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
-            the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `operandNumbers` is the list of operands whose tiles are "producers".
+          - `allOffsets` is the offset of the slice of the producer used by the
+            tiled implementation of the consumer.
+          - `allSizes` is the size of the slice of the producer used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
+          If it is illegal to fuse with a producer along the given operand tiles for
           an operation, the implementation should return a failure.
         }],
         /*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
-        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*methodName=*/"getTiledImplementationFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           tile of the operand.
 
           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of an operand,
-          returns the tile of the iteration space that uses this tile.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
+          with an (already tiled) producer. Given tiles of operands,
+          returns the tile of the iteration space that uses these tiles.
+          - `operandNumbers` is the list of operands whose tiles are "produced"
+            by the producer(s).
+          - `allOffsets` is the offset of the slice of the producers used by
             the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `allSizes` is the size of the slice of the producers used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
-          an operation, or if this mapping cannot be computed, the
-          implementation should return a failure.
+          If it is illegal to fuse with the producer slices for an operation,
+          or if this mapping cannot be computed, the implementation should
+          return a failure.
 
           Note that unlike the "tile consumer and fuse producer" case, the
           "tile producer and fuse consumer" requires an additional method to get
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformation. It does not provide guarantees on whether such a
           transformation is profitable.
 
-          For most cases `getTiledImplementationFromOperandTile` could be a
-          implemented using `getIterationDomainTileFromOperandTile` +
+          For most cases `getTiledImplementationFromOperandTiles` could be a
+          implemented using `getIterationDomainTileFromOperandTiles` +
           `getTiledImplementation` methods.
         }],
         /*retType=*/"::llvm::LogicalResult",
-        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*methodName=*/"getIterationDomainTileFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..86045b54075bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -147,55 +147,80 @@ struct LinalgOpTilingInterface
   /// Utility to fetch the offsets and sizes when applied as per the indexing
   /// map of the linalg op. This helps in fusing the linalg op as a consumer of
   /// a given slice op.
-  void
-  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
-                         ArrayRef<OpFoldResult> offsets,
-                         ArrayRef<OpFoldResult> sizes,
-                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
-                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
-    unsigned numLoops = linalgOp.getNumLoops();
-    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
-    mappedOffsets.resize(numLoops);
-    mappedSizes.resize(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
-        mappedOffsets[index] = value.offset;
-        mappedSizes[index] = value.size;
+  static LogicalResult
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+                         ArrayRef<AffineMap> indexingMaps,
+                         ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+                         ArrayRef<SmallVector<OpFoldResult>> allSizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
+                         SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
+    DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+    // For each loop dim, record the offset/size given by the operand tiles.
+    for (auto [indexingMap, offsets, sizes] :
+         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+      for (auto [resultExpr, offset, size] :
+           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+        if (!dimExpr)
+          continue;
+        unsigned position = dimExpr.getPosition();
+        auto it = mappedOffsets.find(position);
+        if (it != mappedOffsets.end()) {
+          OpFoldResult seenOffset = it->second;
+          OpFoldResult seenSize = mappedSizes.lookup(position);
+          if (seenOffset != offset || seenSize != size) {
+            return linalgOp->emitOpError(
+                "inconsistent iteration space mapping from offsets/sizes of "
+                "operands/results");
+          }
+        } else {
+          mappedOffsets[position] = offset;
+          mappedSizes[position] = size;
+        }
       }
     }
-    for (const auto &&[index, value] :
-         llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
-      mappedOffsets[dimPosition] = offsets[index];
-      mappedSizes[dimPosition] = sizes[index];
+
+    // Dimensions constrained by an operand tile take that tile's offset/size;
+    // all remaining dimensions default to the full iteration domain.
+    SmallVector<Range> iterationDomain =
+        cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
+    mappedOffsetsVec.resize(iterationDomain.size());
+    mappedSizesVec.resize(iterationDomain.size());
+    for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
+      auto it = mappedOffsets.find(index);
+      if (it != mappedOffsets.end()) {
+        mappedOffsetsVec[index] = it->second;
+        mappedSizesVec[index] = mappedSizes.lookup(index);
+        continue;
+      }
+      mappedOffsetsVec[index] = domain.offset;
+      mappedSizesVec[index] = domain.size;
     }
+    return success();
   }
 
   /// Method to return the position of the result tile computed by the tiled
   /// operation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
       SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
-    // Check that the indexing map used for the operand is a projected
-    // permutation. This could be relaxed with a more general approach that can
-    // map the offsets and sizes from the operand to iteration space tiles
-    // (filling in full extent for dimensions not used to access the result).
-    AffineMap indexingMap =
-        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
-    if (!indexingMap.isProjectedPermutation()) {
-      return op->emitError()
-             << "unhandled get iter domain position when operand is not "
-                "accessed using a permuted projection";
+    auto indexingMaps = llvm::map_to_vector(operandNumbers, [&](unsigned n) {
+      return linalgOp.getMatchingIndexingMap(&op->getOpOperand(n));
+    });
+    if (llvm::any_of(indexingMaps,
+                     [](AffineMap m) { return !m.isProjectedPermutation(); }))
+      return op->emitError("unhandled get iter domain position when operand "
+                           "is not accessed using a permuted projection");
+    if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
+                                      allSizes, iterDomainOffsets,
+                                      iterDomainSizes))) {
+      return failure();
     }
-
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
     return success();
   }
 
@@ -246,8 +271,13 @@ struct LinalgOpTilingInterface
           "accessed using a permuted projection");
     }
 
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
+    SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
+    SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
+    auto status =
+        getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
+                               {allSizes}, iterDomainOffsets, iterDomainSizes);
+    (void)status;
+    assert(succeeded(status) && "unexpected error in offset calculation");
     return success();
   }
 
@@ -278,12 +308,13 @@ struct LinalgOpTilingInterface
 
   /// Method to generate the tiled implementation of an operation from the tile
   /// of the operand.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
     SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, operandNumber, offsets, sizes, mappedOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
             mappedSizes))) {
       return failure();
     }
@@ -750,13 +781,17 @@ struct PackOpTiling
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
   /// `resultSizes` only cover outer dimensions.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
-    if (operandNumber != 0)
-      return failure();
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+      return op->emitOpError("unsupporeted operands for consumer fusion");
+
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
     // It is not trivial to infer dest tile from source tile if `packOp` has
@@ -817,11 +852,15 @@ struct PackOpTiling
   }
 
   /// Method to return the tiled implementation of tensor.pack as a consumer.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
-    if (operandNumber != 0)
-      return failure();
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+      return op->emitOpError("unhandled operands for consumer fusion");
+
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
     Location loc = packOp.getLoc();
@@ -836,8 +875,8 @@ struct PackOpTiling
     tiledOperands.push_back(sourceSlice);
 
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
             outerDimSizes)))
       return failure();
 
@@ -1095,12 +1134,20 @@ struct UnPackOpTiling
 
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
+    if (operandNumbers.size() != 1) {
+      return op->emitOpError("unable to handle multiple operands");
+    }
     auto unPackOp = cast<UnPackOp>(op);
+    unsigned operandNumber = operandNumbers[0];
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
+
     // If the operand tile is the dest, then no adjustment is needed.
     if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
       resultOffsets = llvm::to_vector(offsets);
@@ -1154,10 +1201,17 @@ struct UnPackOpTiling
   }
 
   /// Method to return the tiled implementation of tensor.unpack as a consumer.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
+      return op->emitOpError("unhandled operands for consumer fusion");
+    }
     auto unPackOp = cast<UnPackOp>(op);
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
+
     // tensor.unpack op is fusible (as a consumer) only if inner dims are not
     // tiled.
     int64_t numTiles = unPackOp.getInnerDimsPos().size();
@@ -1172,8 +1226,8 @@ struct UnPackOpTiling
     // Fetch offset/size for creating the slice of the dest operand of
     // unpack op.
     SmallVector<OpFoldResult> outputOffsets, outputSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
             outputSizes)))
       return failure();
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 3f29dd3ac5e48..4e19c2770f722 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1977,53 +1977,118 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
 
 /// A utility to fetch an untiled consumer of
 /// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
-                            MutableArrayRef<LoopLikeOpInterface> loops) {
+static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
   assert(!loops.empty() && "unexpected empty loops");
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
-  } else {
-    return failure();
+  assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
+  SmallVector<OpOperand *> fusedOperands;
+  for (Operation *sliceOp : sliceOps) {
+    FailureOr<OpOperand *> fusedOperand =
+        TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
+            .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+                [&](auto op) {
+                  return getUntiledConsumerFromSlice(rewriter, op, loops);
+                })
+            .Default([](Operation *op) {
+              return op->emitOpError("unhandled slice type");
+            });
+    // Propagate the failure to fetch the consumer use of this slice.
+    if (failed(fusedOperand))
+      return failure();
+    if (!fusedOperands.empty() &&
+        (*fusedOperand)->getOwner() != fusedOperands.front()->getOwner()) {
+      return fusedOperands.front()->getOwner()->emitOpError(
+          "all candidate slices must be to the same consumer");
+    }
+    fusedOperands.push_back(*fusedOperand);
+  }
+  return fusedOperands;
+}
+
+/// Clone `insertSliceOp` unchanged at the current insertion point of
+/// `rewriter`.
+static tensor::InsertSliceOp
+cloneAsInsertSlice(RewriterBase &rewriter,
+                   tensor::InsertSliceOp insertSliceOp) {
+  return cast<tensor::InsertSliceOp>(
+      rewriter.clone(*insertSliceOp.getOperation()));
+}
+
+/// Re-create a `tensor.parallel_insert_slice` as a plain
+/// `tensor.insert_slice` (same source, dest, offsets, sizes and strides) at
+/// the current insertion point of `rewriter`.
+static tensor::InsertSliceOp
+cloneAsInsertSlice(RewriterBase &rewriter,
+                   tensor::ParallelInsertSliceOp insertSliceOp) {
+  return rewriter.create<tensor::InsertSliceOp>(
+      insertSliceOp->getLoc(), insertSliceOp.getSource(),
+      insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
+      insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+}
+
+/// Clone each of `candidateSlices` as a `tensor.insert_slice`, in order.
+static SmallVector<tensor::InsertSliceOp>
+cloneAsInsertSlices(RewriterBase &rewriter,
+                    ArrayRef<Operation *> candidateSlices) {
+  assert(!candidateSlices.empty() &&
+         "unexpected empty list of slices to clone");
+  SmallVector<tensor::InsertSliceOp> clonedSlices;
+  for (Operation *sliceOp : candidateSlices) {
+    TypeSwitch<Operation *>(sliceOp)
+        .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+            [&](auto op) {
+              clonedSlices.push_back(cloneAsInsertSlice(rewriter, op));
+            })
+        .Default([](Operation *) {
+          llvm_unreachable(
+              "unexpected slice type while cloning as insert slice");
+        });
   }
+  return clonedSlices;
 }
 
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(
-    RewriterBase &rewriter, Operation *candidateSliceOp,
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
     MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return emitError(rewriter.getUnknownLoc(),
+                     "no candidate slices provided for consumer fusion");
+  }
   // Return if `loops` is empty, return an error for now. Caller is expected
   // to handle this case.
   if (loops.empty()) {
-    return candidateSliceOp->emitOpError(
+    return candidateSlices.front()->emitOpError(
         "cannot call tile and fuse consumer with an empty loop nest");
   }
-  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
-          candidateSliceOp))
-    return failure();
+
+  if (!(llvm::all_of(
+            candidateSlices,
+            [](Operation *op) { return isa<tensor::InsertSliceOp>(op); }) ||
+        llvm::all_of(candidateSlices, [](Operation *op) {
+          return isa<tensor::ParallelInsertSliceOp>(op);
+        }))) {
+    return candidateSlices.front()->emitOpError(
+        "candidates slices need to be all `tensor.extract_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
 
   // 1. Get the consumer of scf.for for the result yielded by
   // tensor.insert_slice/parallel_insert_slice.
-  FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
-  if (failed(maybeConsumerOpOperand)) {
-    return rewriter.notifyMatchFailure(candidateSliceOp,
-                                       "could not fetch consumer to fuse");
-  }
-  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
-  Operation *consumerOp = consumerOpOperand->getOwner();
-  unsigned operandNumber = consumerOpOperand->getOperandNumber();
-  unsigned resultNumber = 0;
-  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
-    resultNumber = producerResult.getResultNumber();
-  } else {
-    return rewriter.notifyMatchFailure(
-        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+  SmallVector<OpOperand *> consumerOpOperands;
+  Operation *consumerOp;
+  {
+    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+    if (failed(maybeConsumerOpOperand)) {
+      return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                         "could not fetch consumer to fuse");
+    }
+    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+    consumerOp = consumerOpOperands.front()->getOwner();
   }
 
   LoopLikeOpInterface outerMostLoop = loops.front();
@@ -2043,16 +2108,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   if (!dstOp)
     return rewriter.notifyMatchFailure(consumerOp,
                                        "consumer op is not DPS operation");
-  SmallVector<Value> dpsInits =
-      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
+  if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
+        return dstOp.isDpsInit(opOperand);
+      })) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  SmallVector<Value> newInits = dpsInits;
-
-  Location loc = outerMostLoop->getLoc();
+  SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
 
   // 3. Move the whole loop structure right before firstUserOfLoop, the
   // dominance should be already ensured by `checkAssumptionForLoop`.
@@ -2067,43 +2130,52 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
-  tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
-          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
-    rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+    rewriter.setInsertionPoint(candidateSlices.front());
   }
+  // 5.a. Clone all the candidate slices as equivalent insert slice ops.
+  SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
+      cloneAsInsertSlices(rewriter, candidateSlices);
 
-  // 5.a. Clone consumer op.
+  // 5.b. Clone consumer op.
   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
+  SmallVector<unsigned> operandNumbers =
+      llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
+        return opOperand->getOperandNumber();
+      });
+  SmallVector<OpOperand *> clonedOpFusedOperandsList =
+      llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
+        return &clonedConsumerOp->getOpOperand(operandNum);
+      });
 
-  // 5.b. Replace all uses of the loop result with the result of the cloned
+  // 5.c. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
-  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-    operandToReplace.set(clonedInsertSliceOp.getResult());
+    for (auto [operandToReplace, clonedSliceOp] :
+         llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
+      operandToReplace->set(clonedSliceOp.getResult());
+    }
   });
 
   // 6. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
-  auto ossSliceOp =
-      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
-      tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+      tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
+                                                   clonedOpFusedOperandsList);
   if (failed(tileAndFuseResult)) {
     return failure();
   }
+
   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
-  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
-                              clonedInsertSliceOp.getSource());
+  for (auto [operandNum, clonedSliceOp] :
+       llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
+    rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
+                                clonedSliceOp.getSource());
+  }
 
   // 7. Reconstruct [nested] loop with new inits.
   YieldTiledValuesFn newYieldValuesFn =
@@ -2115,14 +2187,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     // 8. Set inner insertPoint right before tiled consumer op.
     innerRewriter.setInsertionPoint(tiledConsumerOp);
 
-    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+    SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
+    for (auto candidateSliceOp : clonedInsertSlices) {
+      SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+      SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
 
-    // 9. Check all insert stride is 1.
-    if (!llvm::all_of(strides, isOneInteger)) {
-      return rewriter.notifyMatchFailure(
-          candidateSliceOp, "containingOp's result yield with stride");
+      // 9. Record offsets/sizes (unit strides were checked during tiling).
+      allOffsets.emplace_back(std::move(offsets));
+      allSizes.emplace_back(std::move(sizes));
     }
 
     // 10. Try to get iter domain position from input position. Use
@@ -2132,8 +2204,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     // tiledConsumerOp could lead to some chained unnecessary extra index
     // computation.
     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
+            rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
             iterDomainSizes))) {
       return rewriter.notifyMatchFailure(
           clonedConsumerOp,
@@ -2209,10 +2281,13 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   // 16. Need to erase the old scf loop and the cloned consumer op.
   rewriter.eraseOp(clonedConsumerOp);
 
+  SmallVector<OpOperand *> tiledAndFusedOpOperands =
+      llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
+        return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
+      });
   return scf::SCFFuseConsumerOfSliceResult{
-      consumerOpOperand,
-      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
-      tileAndFuseResult->tiledOps};
+      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(tileAndFuseResult->tiledOps)};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 6f33f9b55ceb6..10831e790d739 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -39,21 +39,48 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
   return *tiledResult;
 }
 
-FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
-    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
-    OpOperand &consumer) {
-  auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
+FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
+    OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
+    ArrayRef<OpOperand *> consumerOperands) {
+  if (sliceOps.empty()) {
+    return emitError(builder.getUnknownLoc(),
+                     "expected candidate slices list to be non-empty");
+  }
+  if (sliceOps.size() != consumerOperands.size()) {
+    return sliceOps.front()->emitOpError(
+        "expected as many operands as the number of slices passed");
+  }
+  auto consumerOp =
+      dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
   if (!consumerOp)
     return failure();
+  for (auto opOperand : consumerOperands.drop_front()) {
+    if (opOperand->getOwner() != consumerOp) {
+      return consumerOp->emitOpError(
+          "expected all consumer operands to be from the same operation");
+    }
+  }
 
-  // `TilingInterface` currently only supports strides being 1.
-  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
-    return failure();
+  auto consumerOperandNums = llvm::map_to_vector(
+      consumerOperands, [](OpOperand *opOperand) -> unsigned {
+        return opOperand->getOperandNumber();
+      });
+  SmallVector<SmallVector<OpFoldResult>> allOffsets;
+  SmallVector<SmallVector<OpFoldResult>> allSizes;
+  for (auto sliceOp : sliceOps) {
+
+    // `TilingInterface` currently only supports strides being 1.
+    if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
+      return failure();
 
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+    allOffsets.emplace_back(std::move(offsets));
+    allSizes.emplace_back(std::move(sizes));
+  }
   FailureOr<TilingResult> tiledResult =
-      consumerOp.getTiledImplementationFromOperandTile(
-          builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
-          sliceOp.getMixedSizes());
+      consumerOp.getTiledImplementationFromOperandTiles(
+          builder, consumerOperandNums, allOffsets, allSizes);
   if (failed(tiledResult))
     return failure();
 
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 77e52946b830f..569f12d5b615e 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -620,3 +620,292 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) =  (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic:2 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	%1 = arith.addf %b0, %b2 : f32
+	linalg.yield %0, %1 : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }	
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion1(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//   CHECK-DAG:     %[[GENERIC:.+]]:2 = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Check consumer fusion when the two yielded slices are produced by different ops inside the loop.
+
+func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) =  (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+	%0 = arith.addf %b0, %b1 : f32
+	linalg.yield %0: f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }	
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion2(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]], %[[GENERIC1]] :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) =  (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	      iterator_types = ["parallel", "reduction"]}
+	      ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+      	%0 = arith.addf %b0, %b1 : f32
+	      linalg.yield %0: f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  // expected-error @below {{unable to apply consumer fusion along these operands since the iteration spaces are inconsistent}}
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @multi_slice_fusion_2(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]], %[[GENERIC1]] :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+// -----
+
+func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) =  (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+      	%0 = arith.addf %b0, %b1 : f32
+	      linalg.yield %0: f32
+    } -> tensor<?x?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  // expected-error @below {{unable to apply consumer fusion along these operands since the iteration spaces are inconsistent}}
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 9971f0cde4ed2..3c810f6f4accf 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -168,29 +168,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 
 /// Apply fusing of consumer transformation to all payload ops and store both
 /// the original consumer operation as well as the fused consumer operation.
-template <typename Range>
 static LogicalResult applyFuseConsumer(
-    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
-    MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
-    TransformResults &transformResults) {
+    RewriterBase &rewriter, Operation *transformOp,
+    ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
+    uint32_t numConsumerToFuse, TransformResults &transformResults) {
   SmallVector<Operation *> originalConsumerOps;
   SmallVector<Operation *> fusedConsumerOps;
 
-  for (Operation *target : payloadOps) {
-    rewriter.setInsertionPoint(target);
+  rewriter.setInsertionPoint(slices.front());
 
-    while (numConsumerToFuse--) {
-      FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
-          scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
+  while (numConsumerToFuse--) {
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+        scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
 
-      if (failed(fuseConsumerResults))
-        return failure();
+    if (failed(fuseConsumerResults))
+      return failure();
 
-      // Report back the relevant handles to the transform op.
-      originalConsumerOps.push_back(
-          fuseConsumerResults->origConsumerOperand->getOwner());
-      fusedConsumerOps.push_back(
-          fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+    // Report back the relevant handles to the transform op.
+    for (OpOperand *origConsumerOperand :
+         fuseConsumerResults->origConsumerOperands) {
+      originalConsumerOps.push_back(origConsumerOperand->getOwner());
+    }
+    for (OpOperand *tiledAndFusedConsumerOperand :
+         fuseConsumerResults->tiledAndFusedConsumerOperands) {
+      fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
     }
   }
 
@@ -203,6 +204,12 @@ DiagnosedSilenceableFailure
 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                      TransformResults &transformResults,
                                      TransformState &state) {
+  SmallVector<Operation *> slices;
+  for (auto op : getTargets()) {
+    auto sliceOp = *state.getPayloadOps(op).begin();
+    slices.push_back(sliceOp);
+  }
+
   SmallVector<LoopLikeOpInterface> loops;
   for (auto op : llvm::reverse(getLoops())) {
     auto loopLikeOp =
@@ -212,16 +219,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
     }
     loops.push_back(loopLikeOp);
   }
-  LogicalResult result = applyFuseConsumer(
-      rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
-      getNumConsumerToFuse(), transformResults);
+  LogicalResult result =
+      applyFuseConsumer(rewriter, getOperation(), slices, loops,
+                        getNumConsumerToFuse(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
 }
 
 void transform::TestFuseConsumerOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  consumesHandle(getTargetMutable(), effects);
+  consumesHandle(getTargetsMutable(), effects);
   consumesHandle(getLoopsMutable(), effects);
   producesHandle(getOperation()->getOpResults(), effects);
   modifiesPayload(effects);
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 98f7145c99cb1..3c09082e192ea 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
 }
 
 def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
+       [AttrSizedOperandSegments,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
         ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
@@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
   }];
 
   let arguments = (ins 
-      TransformHandleTypeInterface:$target,
+      Variadic<TransformHandleTypeInterface>:$targets,
       Variadic<TransformHandleTypeInterface>:$loops,
       DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
   let results = (outs TransformHandleTypeInterface:$consumer,
                       TransformHandleTypeInterface:$fused_consumer);
 
   let assemblyFormat = [{
-    $target `in` `(` $loops `)`
+    $targets `in` `(` $loops `)`
     (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? 
     attr-dict `:` functional-type(operands, results)
   }];



More information about the Mlir-commits mailing list