[Mlir-commits] [mlir] c873e5f - [mlir][TilingInterface] Handle multi operand consumer fusion. (#145193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 25 11:54:42 PDT 2025


Author: MaheshRavishankar
Date: 2025-06-25T11:54:38-07:00
New Revision: c873e5f87d84bab700a297af8cdb76b2bd3ece88

URL: https://github.com/llvm/llvm-project/commit/c873e5f87d84bab700a297af8cdb76b2bd3ece88
DIFF: https://github.com/llvm/llvm-project/commit/c873e5f87d84bab700a297af8cdb76b2bd3ece88.diff

LOG: [mlir][TilingInterface] Handle multi operand consumer fusion. (#145193)

For consumer fusion cases of this form

```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {

  tensor.parallel_insert_slice ... into %arg0
  tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```

the current consumer fusion that handles one slice at a time cannot fuse
the consumer into the loop, since fusing along one slice will create an
SSA violation on the other use from the `scf.forall`. The solution is to
allow consumer fusion to consider multiple slices at once. This
PR changes the `TilingInterface` methods related to consumer fusion,
i.e.

- `getTiledImplementationFromOperandTile`
- `getIterationDomainTileFromOperandTile`

to allow fusion while considering multiple operands. It is up to the
`TilingInterface` implementation to return an error if a list of tiles
of the operands cannot result in a consistent implementation of the
tiled operation.

The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes and to handle cases where the
operand tiles can result in a consistent tiling implementation.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 7b43aa43c7517..3205da6e448fc 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -313,19 +313,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                      TilingInterface consumer,
                                      const SCFTileAndFuseOptions &options);
 
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place.  Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer `candidateSlices` by computing the required slice of the
+/// consumer in-place. All the entries of `candidateSlices` are expected to map
+/// to the same consumer. The method returns an error if the consumer cannot be
+/// tiled in a manner that is consistent for all the passed slices. Note that
+/// the method replaces the uses of `candidateSlices` with the tiled and fused
+/// consumer value but does not delete the slice operations.
 struct SCFFuseConsumerOfSliceResult {
-  OpOperand *origConsumerOperand; // Original untiled consumer's operand.
-  OpOperand
-      *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+  // Original untiled consumer operands.
+  SmallVector<OpOperand *> origConsumerOperands;
+  // Tiled and fused consumer operands.
+  SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
   SmallVector<Operation *> tiledOps;
 };
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
-                           MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+                            ArrayRef<Operation *> candidateSlices,
+                            MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.

diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..87deef9ca7466 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap `tensor.insert_slice`s with their consumers when the
+/// consumer implements the `TilingInterface`. The size of `sliceOps` and
+/// `consumerOperands` is expected to be the same. Every entry in
+/// `consumerOperands` represents a use of the the corresponding
+/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` is
+/// expected to be uses in the same consumer.
 FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
-                                    OffsetSizeAndStrideOpInterface sliceOp,
-                                    OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+                                     ArrayRef<tensor::InsertSliceOp> sliceOps,
+                                     ArrayRef<OpOperand *> consumerOperands);
 
 //===----------------------------------------------------------------------===//
 // Populate functions.

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 31f54413a5ff0..663c256c848df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
   using PointerUnion<Attribute, Value>::PointerUnion;
 
 public:
-  void dump() const { llvm::errs() << *this << "\n"; }
+  LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
 
   MLIRContext *getContext() const {
     PointerUnion pu = *this;

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 0de37338c95e4..0c0fc88aec95a 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the tiled implementation of an operation that uses
-          exactly a tile of the given operand.
+          the exact tiles of the given operands.
 
           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of the producer, this
-          method generates the tile of the consumer that uses exactly this
-          produced tile. In some sense it is the "reverse" of
+          with an (already tiled) producer. Given tiles of the producer, this
+          method generates the tile of the consumer that uses exactly these
+          produced tiles. In some sense it is the "reverse" of
           `generateResultTileValue`.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
-            the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `operandNumbers` is the list of operands whose tiles are "producers".
+          - `allOffsets` is the offset of the slice of the producer used by the
+            tiled implementation of the consumer.
+          - `allSizes` is the size of the slice of the producer used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
+          If it is illegal to fuse with a producer along the given operand tiles for
           an operation, the implementation should return a failure.
         }],
         /*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
-        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*methodName=*/"getTiledImplementationFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           tile of the operand.
 
           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of an operand,
-          returns the tile of the iteration space that uses this tile.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
+          with an (already tiled) producer. Given tiles of operands,
+          returns the tile of the iteration space that uses these tiles.
+          - `operandNumbers` is the list of operands whose tiles are "produced"
+            by the producer(s).
+          - `allOffsets` is the offset of the slice of the producers used by
             the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `allSizes` is the size of the slice of the producers used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
-          an operation, or if this mapping cannot be computed, the
-          implementation should return a failure.
+          If it is illegal to fuse with the producer slices for an operation,
+          or if this mapping cannot be computed, the implementation should
+          return a failure.
 
           Note that unlike the "tile consumer and fuse producer" case, the
           "tile producer and fuse consumer" requires an additional method to get
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformation. It does not provide guarantees on whether such a
           transformation is profitable.
 
-          For most cases `getTiledImplementationFromOperandTile` could be a
-          implemented using `getIterationDomainTileFromOperandTile` +
+          For most cases `getTiledImplementationFromOperandTiles` could be a
+          implemented using `getIterationDomainTileFromOperandTiles` +
           `getTiledImplementation` methods.
         }],
         /*retType=*/"::llvm::LogicalResult",
-        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*methodName=*/"getIterationDomainTileFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 19d484a3bb701..513cecef29b61 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -22,8 +22,11 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/Support/Debug.h"
 #include <optional>
 
+#define DEBUG_TYPE "linalg-tiling-interface-impl"
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -148,55 +151,82 @@ struct LinalgOpTilingInterface
   /// Utility to fetch the offsets and sizes when applied as per the indexing
   /// map of the linalg op. This helps in fusing the linalg op as a consumer of
   /// a given slice op.
-  void
-  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
-                         ArrayRef<OpFoldResult> offsets,
-                         ArrayRef<OpFoldResult> sizes,
-                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
-                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
-    unsigned numLoops = linalgOp.getNumLoops();
-    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
-    mappedOffsets.resize(numLoops);
-    mappedSizes.resize(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
-        mappedOffsets[index] = value.offset;
-        mappedSizes[index] = value.size;
+  static LogicalResult
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+                         ArrayRef<AffineMap> indexingMaps,
+                         ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+                         ArrayRef<SmallVector<OpFoldResult>> allSizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
+                         SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
+    DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+
+    for (auto [indexingMap, offsets, sizes] :
+         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+      for (auto [resultExpr, offset, size] :
+           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+        if (!dimExpr)
+          continue;
+        unsigned position = dimExpr.getPosition();
+        auto it = mappedOffsets.find(position);
+        if (it != mappedOffsets.end()) {
+          OpFoldResult seenOffset = it->second;
+          OpFoldResult seenSize = mappedSizes.lookup(position);
+          if (seenOffset != offset || seenSize != size) {
+            LLVM_DEBUG({
+              llvm::dbgs() << "inconsistent iteration space mapping from "
+                              "offsets/sizes of operands/results";
+            });
+            return failure();
+          }
+        } else {
+          mappedOffsets[position] = offset;
+          mappedSizes[position] = size;
+        }
       }
     }
-    for (const auto &&[index, value] :
-         llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
-      mappedOffsets[dimPosition] = offsets[index];
-      mappedSizes[dimPosition] = sizes[index];
+
+    // Aggregate from the given operand offsets and sizes, or default to
+    // iteration space values.
+    SmallVector<Range> iterationDomain =
+        cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
+    mappedOffsetsVec.resize(iterationDomain.size());
+    mappedSizesVec.resize(iterationDomain.size());
+    for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
+      auto it = mappedOffsets.find(index);
+      if (it != mappedOffsets.end()) {
+        mappedOffsetsVec[index] = it->second;
+        mappedSizesVec[index] = mappedSizes.lookup(index);
+        continue;
+      }
+      mappedOffsetsVec[index] = domain.offset;
+      mappedSizesVec[index] = domain.size;
     }
+    return success();
   }
 
   /// Method to return the position of the result tile computed by the tiled
   /// operation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
       SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
-    // Check that the indexing map used for the operand is a projected
-    // permutation. This could be relaxed with a more general approach that can
-    // map the offsets and sizes from the operand to iteration space tiles
-    // (filling in full extent for dimensions not used to access the result).
-    AffineMap indexingMap =
-        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
-    if (!indexingMap.isProjectedPermutation()) {
-      return op->emitError()
-             << "unhandled get iter domain position when operand is not "
-                "accessed using a permuted projection";
+    std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
+        iterationSpaceSizes;
+    SmallVector<AffineMap> indexingMaps =
+        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+          return linalgOp.getMatchingIndexingMap(&opOperand);
+        });
+    if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
+                                      allSizes, iterDomainOffsets,
+                                      iterDomainSizes))) {
+      return failure();
     }
-
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
     return success();
   }
 
@@ -247,8 +277,13 @@ struct LinalgOpTilingInterface
           "accessed using a permuted projection");
     }
 
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
+    SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
+    SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
+    auto status =
+        getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
+                               {allSizes}, iterDomainOffsets, iterDomainSizes);
+    (void)status;
+    assert(succeeded(status) && "unexpected error in offset calculation");
     return success();
   }
 
@@ -279,12 +314,13 @@ struct LinalgOpTilingInterface
 
   /// Method to generate the tiled implementation of an operation from the tile
   /// of the operand.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
     SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, operandNumber, offsets, sizes, mappedOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
             mappedSizes))) {
       return failure();
     }
@@ -837,13 +873,20 @@ struct PackOpTiling
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
   /// `resultSizes` only cover outer dimensions.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
-    if (operandNumber != 0)
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
+      LLVM_DEBUG(
+          { llvm::dbgs() << "unsupported operands for consumer fusion"; });
       return failure();
+    }
+
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
     // It is not trivial to infer dest tile from source tile if `packOp` has
@@ -904,11 +947,18 @@ struct PackOpTiling
   }
 
   /// Method to return the tiled implementation of tensor.pack as a consumer.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
-    if (operandNumber != 0)
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
+      LLVM_DEBUG(
+          { llvm ::dbgs() << "unhandled operands for consumer fusion"; });
       return failure();
+    }
+
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
     Location loc = packOp.getLoc();
@@ -923,8 +973,8 @@ struct PackOpTiling
     tiledOperands.push_back(sourceSlice);
 
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
             outerDimSizes)))
       return failure();
 
@@ -1182,12 +1232,21 @@ struct UnPackOpTiling
 
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
+    if (operandNumbers.size() != 1) {
+      LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
+      return failure();
+    }
     auto unPackOp = cast<UnPackOp>(op);
+    unsigned operandNumber = operandNumbers[0];
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
+
     // If the operand tile is the dest, then no adjustment is needed.
     if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
       resultOffsets = llvm::to_vector(offsets);
@@ -1241,10 +1300,18 @@ struct UnPackOpTiling
   }
 
   /// Method to return the tiled implementation of tensor.unpack as a consumer.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
+      LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
+      return failure();
+    }
     auto unPackOp = cast<UnPackOp>(op);
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
+
     // tensor.unpack op is fusible (as a consumer) only if inner dims are not
     // tiled.
     int64_t numTiles = unPackOp.getInnerDimsPos().size();
@@ -1259,8 +1326,8 @@ struct UnPackOpTiling
     // Fetch offset/size for creating the slice of the dest operand of
     // unpack op.
     SmallVector<OpFoldResult> outputOffsets, outputSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
             outputSizes)))
       return failure();
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index ddcae8481a5b4..995120ad8680e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -2047,53 +2047,119 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
 
 /// A utility to fetch an untiled consumer of
 /// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
-                            MutableArrayRef<LoopLikeOpInterface> loops) {
+static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
   assert(!loops.empty() && "unexpected empty loops");
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
-  } else {
-    return failure();
+  assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
+  SmallVector<OpOperand *> fusedOperands;
+  for (auto sliceOp : sliceOps) {
+    FailureOr<OpOperand *> fusedOperand =
+        TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
+            .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+                [&](auto op) {
+                  return getUntiledConsumerFromSlice(rewriter, op, loops);
+                })
+            .Default([&](Operation *op) {
+              return rewriter.notifyMatchFailure(op, "unhandled slice type");
+            });
+    if (failed(fusedOperand)) {
+      return failure();
+    }
+    if (!fusedOperands.empty() &&
+        fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
+      return rewriter.notifyMatchFailure(
+          fusedOperand.value()->getOwner(),
+          "all candidate slices must be to the same consumer");
+    }
+    fusedOperands.push_back(fusedOperand.value());
   }
+  return fusedOperands;
+}
+
+template <typename InsertSliceOpTy>
+static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
+                                                InsertSliceOpTy sliceOp);
+
+template <>
+tensor::InsertSliceOp
+cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
+                                          tensor::InsertSliceOp insertSliceOp) {
+  return cast<tensor::InsertSliceOp>(
+      rewriter.clone(*insertSliceOp.getOperation()));
+}
+
+template <>
+tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
+  return rewriter.create<tensor::InsertSliceOp>(
+      insertSliceOp->getLoc(), insertSliceOp.getSource(),
+      insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
+      insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+}
+
+static SmallVector<tensor::InsertSliceOp>
+cloneAsInsertSlices(RewriterBase &rewriter,
+                    ArrayRef<Operation *> candidateSlices) {
+  assert(!candidateSlices.empty() &&
+         "unexpected empty list of slices to clone");
+  SmallVector<tensor::InsertSliceOp> clonedSlices;
+  for (auto sliceOp : candidateSlices) {
+    TypeSwitch<Operation *>(sliceOp)
+        .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+            [&](auto op) {
+              auto clonedOp = cloneAsInsertSlice(rewriter, op);
+              clonedSlices.push_back(clonedOp);
+            })
+        .Default([&](Operation *op) {
+          // Assert here assuming this has already been checked.
+          assert(0 && "unexpected slice type while cloning as insert slice");
+        });
+  }
+  return clonedSlices;
 }
 
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(
-    RewriterBase &rewriter, Operation *candidateSliceOp,
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
     MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return rewriter.notifyMatchFailure(
+        rewriter.getUnknownLoc(),
+        "no candidate slices provided for consumer fusion");
+  }
   // Return if `loops` is empty, return an error for now. Caller is expected
   // to handle this case.
   if (loops.empty()) {
-    return candidateSliceOp->emitOpError(
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
         "cannot call tile and fuse consumer with an empty loop nest");
   }
-  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
-          candidateSliceOp))
-    return failure();
+
+  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+        llvm::all_of(candidateSlices,
+                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "candidates slices need to be all `tensor.extract_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
 
   // 1. Get the consumer of scf.for for the result yielded by
   // tensor.insert_slice/parallel_insert_slice.
-  FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
-  if (failed(maybeConsumerOpOperand)) {
-    return rewriter.notifyMatchFailure(candidateSliceOp,
-                                       "could not fetch consumer to fuse");
-  }
-  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
-  Operation *consumerOp = consumerOpOperand->getOwner();
-  unsigned operandNumber = consumerOpOperand->getOperandNumber();
-  unsigned resultNumber = 0;
-  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
-    resultNumber = producerResult.getResultNumber();
-  } else {
-    return rewriter.notifyMatchFailure(
-        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+  SmallVector<OpOperand *> consumerOpOperands;
+  Operation *consumerOp;
+  {
+    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+    if (failed(maybeConsumerOpOperand)) {
+      return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                         "could not fetch consumer to fuse");
+    }
+    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+    consumerOp = consumerOpOperands.front()->getOwner();
   }
 
   LoopLikeOpInterface outerMostLoop = loops.front();
@@ -2113,16 +2179,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   if (!dstOp)
     return rewriter.notifyMatchFailure(consumerOp,
                                        "consumer op is not DPS operation");
-  SmallVector<Value> dpsInits =
-      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
+  if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
+        return dstOp.isDpsInit(opOperand);
+      })) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  SmallVector<Value> newInits = dpsInits;
-
-  Location loc = outerMostLoop->getLoc();
+  SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
 
   // 3. Move the whole loop structure right before firstUserOfLoop, the
   // dominance should be already ensured by `checkAssumptionForLoop`.
@@ -2137,43 +2201,52 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
-  tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
-          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
-    rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+    rewriter.setInsertionPoint(candidateSlices.front());
   }
+  // 5.a. Clone all the candidate slices as equivalent insert slice ops.
+  SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
+      cloneAsInsertSlices(rewriter, candidateSlices);
 
-  // 5.a. Clone consumer op.
+  // 5.b. Clone consumer op.
   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
+  SmallVector<unsigned> operandNumbers =
+      llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
+        return opOperand->getOperandNumber();
+      });
+  SmallVector<OpOperand *> clonedOpFusedOperandsList =
+      llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
+        return &clonedConsumerOp->getOpOperand(operandNum);
+      });
 
-  // 5.b. Replace all uses of the loop result with the result of the cloned
+  // 5.c. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
-  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-    operandToReplace.set(clonedInsertSliceOp.getResult());
+    for (auto [operandToReplace, clonedSliceOp] :
+         llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
+      operandToReplace->set(clonedSliceOp.getResult());
+    }
   });
 
   // 6. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
-  auto ossSliceOp =
-      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
-      tensor::replaceInsertSliceWithTiledConsumer(
-          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+      tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
+                                                   clonedOpFusedOperandsList);
   if (failed(tileAndFuseResult)) {
     return failure();
   }
+
   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
-  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
-                              clonedInsertSliceOp.getSource());
+  for (auto [operandNum, clonedSliceOp] :
+       llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
+    rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
+                                clonedSliceOp.getSource());
+  }
 
   // 7. Reconstruct [nested] loop with new inits.
   YieldTiledValuesFn newYieldValuesFn =
@@ -2185,14 +2258,20 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     // 8. Set inner insertPoint right before tiled consumer op.
     innerRewriter.setInsertionPoint(tiledConsumerOp);
 
-    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+    SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
+    for (auto candidateSliceOp : clonedInsertSlices) {
+      SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+      SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+      SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
 
-    // 9. Check all insert stride is 1.
-    if (!llvm::all_of(strides, isOneInteger)) {
-      return rewriter.notifyMatchFailure(
-          candidateSliceOp, "containingOp's result yield with stride");
+      // 9. Check that every insert stride is 1.
+      if (!llvm::all_of(strides, isOneInteger)) {
+        return rewriter.notifyMatchFailure(
+            candidateSliceOp, "containingOp's result yield with stride");
+      }
+
+      allOffsets.emplace_back(std::move(offsets));
+      allSizes.emplace_back(std::move(sizes));
     }
 
     // 10. Try to get iter domain position from input position. Use
@@ -2202,8 +2281,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     // tiledConsumerOp could lead to some chained unnecessary extra index
     // computation.
     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
+            rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
             iterDomainSizes))) {
       return rewriter.notifyMatchFailure(
           clonedConsumerOp,
@@ -2279,10 +2358,13 @@ mlir::scf::tileAndFuseConsumerOfSlice(
   // 16. Need to erase the old scf loop and the cloned consumer op.
   rewriter.eraseOp(clonedConsumerOp);
 
+  SmallVector<OpOperand *> tiledAndFusedOpOperands =
+      llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
+        return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
+      });
   return scf::SCFFuseConsumerOfSliceResult{
-      consumerOpOperand,
-      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
-      tileAndFuseResult->tiledOps};
+      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(tileAndFuseResult->tiledOps)};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 6f33f9b55ceb6..4392a2c0eb839 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -17,6 +17,9 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tensor-swap-slices"
 
 using namespace mlir;
 
@@ -39,21 +42,55 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
   return *tiledResult;
 }
 
-FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
-    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
-    OpOperand &consumer) {
-  auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
+FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
+    OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
+    ArrayRef<OpOperand *> consumerOperands) {
+  if (sliceOps.empty()) {
+    LLVM_DEBUG(
+        { llvm::dbgs() << "expected candidate slices list to be non-empty"; });
+    return failure();
+  }
+  if (sliceOps.size() != consumerOperands.size()) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "expected as many operands as the number of slices passed";
+    });
+    return failure();
+  }
+  auto consumerOp =
+      dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
   if (!consumerOp)
     return failure();
+  for (auto opOperand : consumerOperands.drop_front()) {
+    if (opOperand->getOwner() != consumerOp) {
+      LLVM_DEBUG({
+        llvm::dbgs()
+            << "expected all consumer operands to be from the same operation";
+      });
+      return failure();
+    }
+  }
 
-  // `TilingInterface` currently only supports strides being 1.
-  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
-    return failure();
+  auto consumerOperandNums = llvm::map_to_vector(
+      consumerOperands, [](OpOperand *opOperand) -> unsigned {
+        return opOperand->getOperandNumber();
+      });
+  SmallVector<SmallVector<OpFoldResult>> allOffsets;
+  SmallVector<SmallVector<OpFoldResult>> allSizes;
+  for (auto sliceOp : sliceOps) {
+
+    // `TilingInterface` currently only supports strides being 1.
+    if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
+      return failure();
 
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+    allOffsets.emplace_back(std::move(offsets));
+    allSizes.emplace_back(std::move(sizes));
+  }
   FailureOr<TilingResult> tiledResult =
-      consumerOp.getTiledImplementationFromOperandTile(
-          builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
-          sliceOp.getMixedSizes());
+      consumerOp.getTiledImplementationFromOperandTiles(
+          builder, consumerOperandNums, allOffsets, allSizes);
   if (failed(tiledResult))
     return failure();
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 572a2ae70e0a4..5bdb5073ee865 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -653,6 +653,7 @@ module {
       %5 = affine.min #map2(%i)[%d0, %idx]
       %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
 
+      // CHECK: linalg.generic
       // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
       // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
       %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 77e52946b830f..0f69875d596f1 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
 
 #map = affine_map<(d0) -> (d0)>
 module {
@@ -620,3 +620,294 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) =  (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic:2 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	%1 = arith.addf %b0, %b2 : f32
+	linalg.yield %0, %1 : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }	
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion1(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//   CHECK-DAG:     %[[GENERIC:.+]]:2 = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Check fusion when the consistent operand tiles come from two different producers.
+
+func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) =  (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	iterator_types = ["parallel", "reduction"]}
+	ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+	%0 = arith.addf %b0, %b1 : f32
+	linalg.yield %0: f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }	
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion2(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//       CHECK:     %[[TILESIZE:.+]] = affine.min
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]], %[[GENERIC1]] :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+//       CHECK:   return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) =  (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+	      iterator_types = ["parallel", "reduction"]}
+	      ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+      	%0 = arith.addf %b0, %b1 : f32
+	      linalg.yield %0: f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1
+//   CHECK-DAG:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
+//       CHECK:   %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+//  CHECK-SAME:       , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+//   CHECK-DAG:     %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
+//   CHECK-DAG:     %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+//       CHECK:     %[[FUSED:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]], %[[GENERIC1]] :
+//       CHECK:     tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+//       CHECK:   return %[[RESULT]]#2
+
+// -----
+
+func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) =  (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+	      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+	      iterator_types = ["parallel", "parallel", "reduction"]}
+	      ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+      	%0 = arith.addf %b0, %b1 : f32
+	      linalg.yield %0: f32
+    } -> tensor<?x?xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 9971f0cde4ed2..ee3eb9522db7e 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -21,6 +21,9 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "test-tiling-interface"
 
 #define GET_OP_CLASSES
 #include "TestTilingInterfaceTransformOps.h.inc"
@@ -168,29 +171,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 
 /// Apply fusing of consumer transformation to all payload ops and store both
 /// the original consumer operation as well as the fused consumer operation.
-template <typename Range>
 static LogicalResult applyFuseConsumer(
-    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
-    MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
-    TransformResults &transformResults) {
+    RewriterBase &rewriter, Operation *transformOp,
+    ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
+    uint32_t numConsumerToFuse, TransformResults &transformResults) {
   SmallVector<Operation *> originalConsumerOps;
   SmallVector<Operation *> fusedConsumerOps;
 
-  for (Operation *target : payloadOps) {
-    rewriter.setInsertionPoint(target);
+  rewriter.setInsertionPoint(slices.front());
 
-    while (numConsumerToFuse--) {
-      FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
-          scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
+  while (numConsumerToFuse--) {
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+        scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
 
-      if (failed(fuseConsumerResults))
-        return failure();
+    if (failed(fuseConsumerResults))
+      return slices.front()->emitOpError("failed to fuse consumer of slice");
 
-      // Report back the relevant handles to the transform op.
-      originalConsumerOps.push_back(
-          fuseConsumerResults->origConsumerOperand->getOwner());
-      fusedConsumerOps.push_back(
-          fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+    // Report back the relevant handles to the transform op.
+    for (OpOperand *origConsumerOperand :
+         fuseConsumerResults->origConsumerOperands) {
+      originalConsumerOps.push_back(origConsumerOperand->getOwner());
+    }
+    for (OpOperand *tiledAndFusedConsumerOperand :
+         fuseConsumerResults->tiledAndFusedConsumerOperands) {
+      fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
     }
   }
 
@@ -203,6 +207,12 @@ DiagnosedSilenceableFailure
 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                      TransformResults &transformResults,
                                      TransformState &state) {
+  SmallVector<Operation *> slices;
+  for (auto op : getTargets()) {
+    auto sliceOp = *state.getPayloadOps(op).begin();
+    slices.push_back(sliceOp);
+  }
+
   SmallVector<LoopLikeOpInterface> loops;
   for (auto op : llvm::reverse(getLoops())) {
     auto loopLikeOp =
@@ -212,16 +222,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
     }
     loops.push_back(loopLikeOp);
   }
-  LogicalResult result = applyFuseConsumer(
-      rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
-      getNumConsumerToFuse(), transformResults);
+  LogicalResult result =
+      applyFuseConsumer(rewriter, getOperation(), slices, loops,
+                        getNumConsumerToFuse(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
 }
 
 void transform::TestFuseConsumerOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  consumesHandle(getTargetMutable(), effects);
+  consumesHandle(getTargetsMutable(), effects);
   consumesHandle(getLoopsMutable(), effects);
   producesHandle(getOperation()->getOpResults(), effects);
   modifiesPayload(effects);

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 98f7145c99cb1..3c09082e192ea 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
 }
 
 def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
+       [AttrSizedOperandSegments,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
         ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
@@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
   }];
 
   let arguments = (ins 
-      TransformHandleTypeInterface:$target,
+      Variadic<TransformHandleTypeInterface>:$targets,
       Variadic<TransformHandleTypeInterface>:$loops,
       DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
   let results = (outs TransformHandleTypeInterface:$consumer,
                       TransformHandleTypeInterface:$fused_consumer);
 
   let assemblyFormat = [{
-    $target `in` `(` $loops `)`
+    $targets `in` `(` $loops `)`
     (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? 
     attr-dict `:` functional-type(operands, results)
   }];


        


More information about the Mlir-commits mailing list