[Mlir-commits] [mlir] [mlir][TilingInterface] Make `tileAndFuseConsumerOfSlice` take surrounding loops as an argument. (PR #132082)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 21 13:33:10 PDT 2025


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/132082

>From 66fc41978f0f479c46b8be3ec3a70f95d74a838f Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 19 Mar 2025 11:36:08 -0700
Subject: [PATCH 1/2] [mlir][TilingInterface] Make `tileAndFuseConsumerOfSlice`
 take surrounding loops as an argument.

This brings the consumer fusion method in sync with the corresponding
producer fusion method `tileAndFuseProducerOfSlice`. Not taking the
loops as input required a complicated analysis to retrieve the
surrounding loops, which was very fragile. Just like the producer
fusion method, the loops are now taken in as an argument; typically
they are the loops created by the tiling methods.

Some utilities are added to check that the loops passed in are
perfectly nested (in the case of an `scf.for` loop nest).

This is change 1 of N to simplify the implementation of tile and fuse
consumers.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../SCF/Transforms/TileUsingInterface.h       |   3 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 152 ++++++++++++------
 2 files changed, 107 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index d2cddfe00ac78..33a43ce2ee7bb 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult {
   SmallVector<Operation *> tiledOps;
 };
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+                           MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index af87fb7a79d04..4fd10b0e30ab0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1890,25 +1890,81 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
   return {nestLoops.rbegin(), nestLoops.rend()};
 }
 
+/// Check that the loop is perfectly nested.
+static bool
+isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "unexpected empty loop nest");
+  if (loops.size() == 1) {
+    return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
+  }
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+    auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
+    auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
+    if (!outerFor || !innerFor) {
+      return false;
+    }
+    auto outerBBArgs = outerFor.getRegionIterArgs();
+    auto innerIterArgs = innerFor.getInitArgs();
+    if (outerBBArgs.size() != innerIterArgs.size()) {
+      return false;
+    }
+
+    for (auto [outerBBArg, innerIterArg] :
+         llvm::zip(outerBBArgs, innerIterArgs)) {
+      if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
+          innerIterArg != outerBBArg) {
+        return false;
+      }
+    }
+
+    auto outerYields =
+        cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
+    auto innerResults = innerFor.getResults();
+    if (outerYields.size() != innerResults.size()) {
+      return false;
+    }
+    for (auto [outerYield, innerResult] :
+         llvm::zip(outerYields, innerResults)) {
+      if (!llvm::hasSingleElement(innerResult.getUses()) ||
+          outerYield != innerResult) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
 /// tensor.insert_slice. This function makes the following assumptions :
 /// 1.  tensor.insert_slice has scf.yield as its only user.
 /// 2.  scf.for's corresponding result has only one use.
 static FailureOr<OpOperand *>
 getUntiledConsumerFromSlice(RewriterBase &rewriter,
-                            tensor::InsertSliceOp candidateSliceOp) {
+                            tensor::InsertSliceOp candidateSliceOp,
+                            MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "unexpected loops to be empty");
+  // 1. Expect slice to be part of the body of the innermost loop.
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  if (containingOp != loops.back()) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp,
+        "expected slice to be within body of inner-most loop");
+  }
+
+  if (!isPerfectlyNestedForLoops(loops)) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "expected passed loops to be perfectly nested.");
+  }
+
   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
     return failure();
   Value sliceResult = candidateSliceOp.getResult();
   // Step 1. Fetch the corresponding output.
   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
   unsigned resultNumber = yieldOpOperand.getOperandNumber();
-  // Step 2. Check containing op is scf.for.
-  Operation *containingOp = candidateSliceOp->getParentOp();
-  auto forOp = dyn_cast<scf::ForOp>(containingOp);
-  if (!forOp)
-    return failure();
-  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
+
+  scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
 
   return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
 }
@@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
 /// by a tensor.parallel_insert_slice.
 static FailureOr<OpOperand *>
 getUntiledConsumerFromSlice(RewriterBase &rewriter,
-                            tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
+                            tensor::ParallelInsertSliceOp candidateSliceOp,
+                            MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "unexpected loops to be empty");
+  // 1. Check that the surrounding loop is a single scf.forall loop.
+  if (loops.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "expected single surrounding scf.forall");
+  }
+  auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
+  if (!forallOp) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "expected single surrounding scf.forall");
+  }
+
+  // 2. Fetch the corresponding output
   Value sliceDest = candidateSliceOp.getDest();
   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
   if (!iterArg)
     return failure();
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
-    return failure();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp)
+  if (iterArg.getOwner()->getParentOp() != forallOp)
     return failure();
+
   unsigned resultNumber =
       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
           .getResultNumber();
 
-  return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
+  return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
 }
 
 /// A utility to fetch an untiled consumer of
 /// tensor.insert_slice/tensor.parallel_insert_slice.
 static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
+                            MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
+  }
+
   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(rewriter, insertSlice);
+    return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
   } else if (auto parallelInsertSlice =
                  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
+    return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
   } else {
     return failure();
   }
@@ -1954,18 +2024,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
-                                      Operation *candidateSliceOp) {
+mlir::scf::tileAndFuseConsumerOfSlice(
+    RewriterBase &rewriter, Operation *candidateSliceOp,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  // If `loops` is empty, return an error for now. The caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return candidateSliceOp->emitOpError(
+        "cannot call tile and fuse consumer with an empty loop nest");
+  }
   if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
           candidateSliceOp))
     return failure();
 
-  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
-
   // 1. Get the consumer of scf.for for the result yielded by
   // tensor.insert_slice/parallel_insert_slice.
   FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
+      getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
   if (failed(maybeConsumerOpOperand)) {
     return rewriter.notifyMatchFailure(candidateSliceOp,
                                        "could not fetch consumer to fuse");
@@ -1981,25 +2056,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  // There are two possible cases regarding `oldLoopOp` here:
-  // 1. single `scf.forall` or `scf.for`.
-  // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
-  // top-level loop is the outer-most one of these nested loops.
-  LoopLikeOpInterface innerMostLoop =
-      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
-  SmallVector<LoopLikeOpInterface> nestedLoops;
-  if (isInsertSliceOp) {
-    nestedLoops = llvm::map_to_vector(
-        getPerfectlyNestedLoopsOutsideOf(
-            cast<scf::ForOp>(innerMostLoop.getOperation())),
-        [](scf::ForOp forOp) {
-          return cast<LoopLikeOpInterface>(forOp.getOperation());
-        });
-  } else {
-    nestedLoops = {innerMostLoop};
-  }
-
-  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+  LoopLikeOpInterface outerMostLoop = loops.front();
+  LoopLikeOpInterface innerMostLoop = loops.back();
 
   // Check assumption for loop with `reorderOperations` disabled.
   if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
@@ -2165,7 +2223,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
     return success();
   };
   // 14. Add new inits to [nested] loops.
-  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
+  if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
                                        newYieldValuesFn))) {
     return rewriter.notifyMatchFailure(tiledConsumerOp,
                                        "unable to add new inits to nest loop");
@@ -2174,9 +2232,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   // 15. Replace the result of scf loop and consumer op with new loop's
   // results.
 
-  for (auto &&[oldResult, newResult] : llvm::zip(
-           consumerOp->getResults(),
-           nestedLoops.front()->getResults().take_back(newInits.size()))) {
+  for (auto &&[oldResult, newResult] :
+       llvm::zip(consumerOp->getResults(),
+                 loops.front()->getResults().take_back(newInits.size()))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 

>From 9c0d42678b1a2fe87abe269771860d3802f0b0df Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 20 Mar 2025 22:31:10 -0700
Subject: [PATCH 2/2] Address comments.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../SCF/Transforms/TileUsingInterface.cpp     |  67 ++----
 .../transform-tile-and-fuse-pack-unpack.mlir  |   4 +-
 .../tile-and-fuse-consumer.mlir               | 196 ++++--------------
 .../TestTilingInterfaceTransformOps.cpp       |  22 +-
 .../TestTilingInterfaceTransformOps.td        |  10 +-
 5 files changed, 81 insertions(+), 218 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 4fd10b0e30ab0..8e407cc1b348f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1846,11 +1846,9 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
   return failure();
 }
 
-/// Find the perfectly nested loops outside of given loop(included) sorted
-/// from outer to inner.
-///
-/// E.g.
-///
+/// Check that the given loops form a perfectly nested `scf.for` loop nest.
+/// The loops are expected to be ordered from outermost to innermost.
+/// For example:
 /// ```
 ///  %0 = scf.for()
 ///    %1 = scf.for()
@@ -1860,37 +1858,7 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
 ///      yield %2
 ///    yield %1
 /// ```
-///
-/// This function will return three perfectly nested loops: %0 + %1 + %2, when
-/// target inner loop is %2.
-static SmallVector<scf::ForOp>
-getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
-  SmallVector<scf::ForOp> nestLoops = {loop};
-  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
-
-  // Check if it is the ForOp that yield the result of inner loop.
-  auto isForOpYieldResultOfInnerLoop =
-      [](scf::ForOp outerLoop) -> LogicalResult {
-    Block *body = outerLoop.getBody();
-    if (!llvm::hasSingleElement(body->without_terminator()))
-      return failure();
-    auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
-    auto innerForOp = dyn_cast<scf::ForOp>(body->front());
-    if (!innerForOp)
-      return failure();
-    // All of innerForOp results should be yielded.
-    return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
-  };
-
-  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
-    nestLoops.push_back(outerLoop);
-    outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
-  }
-  // sorted from outer to inner
-  return {nestLoops.rbegin(), nestLoops.rend()};
-}
-
-/// Check that the loop is perfectly nested.
+/// Here `loops` should be [%0, %1, %2].
 static bool
 isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
   assert(!loops.empty() && "unexpected empty loop nest");
@@ -1911,21 +1879,21 @@ isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
     }
 
     for (auto [outerBBArg, innerIterArg] :
-         llvm::zip(outerBBArgs, innerIterArgs)) {
+         llvm::zip_equal(outerBBArgs, innerIterArgs)) {
       if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
           innerIterArg != outerBBArg) {
         return false;
       }
     }
 
-    auto outerYields =
+    ValueRange outerYields =
         cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
-    auto innerResults = innerFor.getResults();
+    ValueRange innerResults = innerFor.getResults();
     if (outerYields.size() != innerResults.size()) {
       return false;
     }
     for (auto [outerYield, innerResult] :
-         llvm::zip(outerYields, innerResults)) {
+         llvm::zip_equal(outerYields, innerResults)) {
       if (!llvm::hasSingleElement(innerResult.getUses()) ||
           outerYield != innerResult) {
         return false;
@@ -1935,10 +1903,12 @@ isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
   return true;
 }
 
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1.  tensor.insert_slice has scf.yield as its only user.
-/// 2.  scf.for's corresponding result has only one use.
+/// Fetch the untiled consumer of the outermost scf.for's result which is
+/// yielded by a tensor.insert_slice from the innermost scf.for. This function
+/// makes the following assumptions:
+/// 1. tensor.insert_slice has scf.yield as its only user.
+/// 2. scf.for's corresponding result has only one use.
+/// 3. The `loops` passed in are perfectly nested `scf.for` operations.
 static FailureOr<OpOperand *>
 getUntiledConsumerFromSlice(RewriterBase &rewriter,
                             tensor::InsertSliceOp candidateSliceOp,
@@ -1952,6 +1922,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
         "expected slice to be within body of inner-most loop");
   }
 
+  // 2. Check that the loop is perfectly nested.
   if (!isPerfectlyNestedForLoops(loops)) {
     return rewriter.notifyMatchFailure(
         candidateSliceOp, "expected passed loops to be perfectly nested.");
@@ -1960,7 +1931,8 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
     return failure();
   Value sliceResult = candidateSliceOp.getResult();
-  // Step 1. Fetch the corresponding output.
+
+  // 3. Fetch the corresponding output.
   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
   unsigned resultNumber = yieldOpOperand.getOperandNumber();
 
@@ -2007,10 +1979,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
 static FailureOr<OpOperand *>
 getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
                             MutableArrayRef<LoopLikeOpInterface> loops) {
-  if (loops.empty()) {
-    return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
-  }
-
+  assert(!loops.empty() && "unexpected empty loops");
   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
     return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
   } else if (auto parallelInsertSlice =
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 5d4ae4f15d3fd..185fb9b358055 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -170,8 +170,8 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
           : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      transform.test.fuse_consumer %slice_op
-        : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
+      transform.test.fuse_consumer %slice_op in (%forall_op)
+        : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.yield
     }
   }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 8ce05d94c4ad0..77e52946b830f 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -26,10 +26,12 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
     %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %yield
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %yield in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -83,11 +85,13 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
     %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
         : (!transform.any_op)
         -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %first_slice_op
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -153,8 +157,10 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %yield
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %yield in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -220,11 +226,13 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
     %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
         : (!transform.any_op)
         -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %first_slice_op
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -287,8 +295,10 @@ module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
         : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op
-        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
         transform.yield
     }
 }
@@ -348,8 +358,10 @@ module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
         : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op
-        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
         transform.yield
     }
 }
@@ -409,8 +421,10 @@ module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
         : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op
-        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
         transform.yield
     }
 }
@@ -437,143 +451,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-module {
-  func.func @fuse_add_consumer_into_nested_scf_for(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
-    %c0 = arith.constant 0 : index
-    %c64 = arith.constant 64 : index
-    %c256 = arith.constant 256 : index
-    %cst = arith.constant 0.000000e+00 : f32
-    %dest0 = tensor.empty() : tensor<256x256xf32>
-    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest1) -> (tensor<256x256xf32>) {
-      %2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) {
-        %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
-        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
-        %extracted_slice_3 = tensor.extract_slice %arg1[0, %arg5] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
-        %3 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_1 : tensor<64x64xf32>) -> tensor<64x64xf32>
-        %insert_slice = tensor.insert_slice %3 into %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
-        scf.yield %insert_slice : tensor<256x256xf32>
-      }
-      scf.yield %2 : tensor<256x256xf32>
-    }
-    %4 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    return %4 : tensor<256x256xf32>
-  }
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-//      CHECK: func.func @fuse_add_consumer_into_nested_scf_for(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
-// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
-//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
-//      CHECK:   %[[dest1:.*]] = linalg.fill
-// CHECK-SAME:          outs(%[[dest0]] :
-//      CHECK:   %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[dest0]])
-// CHECK-SAME:   {
-//      CHECK:       %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
-// CHECK-SAME:         iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
-// CHECK-SAME:         {
-//      CHECK:            %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
-//      CHECK:            %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
-//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
-// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE]] :
-//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
-// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
-// CHECK-SAME:              outs(%[[ADD_OUT_SLICE]] :
-//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
-//      CHECK:         }
-//      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
-//      CHECK:   }
-//      CHECK:   return %[[LOOP_RESULT1]]#1 :
-
-// -----
-
-// This test case checks fusion of consumer even if the producer has multiple uses.
-// The multiple uses of the producer essentially means that besides the consumer
-// op in concern, the only other uses of the producer are allowed in :-
-// 1. scf.yield
-// 2. tensor.parallel_insert_slice
-
-module {
-  module {
-    func.func @fuse_consumer_for_multi_use_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
-      %c0 = arith.constant 0 : index
-      %c64 = arith.constant 64 : index
-      %c256 = arith.constant 256 : index
-      %cst = arith.constant 0.000000e+00 : f32
-      %0 = tensor.empty() : tensor<256x256xf32>
-      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-      %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %1, %arg5 = %arg2) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
-        %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args(%arg7 = %arg4) -> (tensor<256x256xf32>) {
-          %extracted_slice = tensor.extract_slice %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
-          %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
-          %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg6] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
-          %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice : tensor<64x64xf32>) -> tensor<64x64xf32>
-          %inserted_slice = tensor.insert_slice %5 into %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
-          scf.yield %inserted_slice : tensor<256x256xf32>
-        }
-        %4 = linalg.add ins(%3, %arg5 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-        scf.yield %3, %4 : tensor<256x256xf32>, tensor<256x256xf32>
-      }
-      return %2#0, %2#1 : tensor<256x256xf32>, tensor<256x256xf32>
-    }
-  }
-  module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-      %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-      %consumer, %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield
-    }
-  }
-}
-//      CHECK: func.func @fuse_consumer_for_multi_use_producer(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
-// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
-//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
-//      CHECK:   %[[dest1:.*]] = linalg.fill
-// CHECK-SAME:          outs(%[[dest0]] :
-//      CHECK:   %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]])
-// CHECK-SAME:   {
-//      CHECK:       %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
-// CHECK-SAME:         iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]])
-// CHECK-SAME:         {
-//      CHECK:            %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
-//      CHECK:            %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
-//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
-// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE]] :
-//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
-// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
-// CHECK-SAME:              outs(%[[ADD_OUT_SLICE]] :
-//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
-//      CHECK:         }
-//      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
-//      CHECK:   }
-//      CHECK:   return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
-
-// -----
-
 module {
   func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
     %c0 = arith.constant 0 : index
@@ -599,8 +476,10 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 2
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -662,9 +541,10 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -733,8 +613,10 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 7380b766935ff..45d6ae3820159 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -169,10 +169,13 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 /// Apply fusing of consumer transformation to all payload ops and store both
 /// the original consumer operation as well as the fused consumer operation.
+/// `loops` are the loops surrounding the slice ops; the same array is passed
+/// to every `scf::tileAndFuseConsumerOfSlice` call. It is taken as a
+/// `MutableArrayRef`, presumably so the callee can update it in place.
 template <typename Range>
-static LogicalResult
-applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
-                  Range &&payloadOps, uint32_t numConsumerToFuse,
-                  TransformResults &transformResults) {
+static LogicalResult applyFuseConsumer(
+    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
+    MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
+    TransformResults &transformResults) {
   SmallVector<Operation *> originalConsumerOps;
   SmallVector<Operation *> fusedConsumerOps;
 
@@ -181,7 +181,7 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
 
     while (numConsumerToFuse--) {
       FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
-          scf::tileAndFuseConsumerOfSlice(rewriter, target);
+          scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
 
       if (failed(fuseConsumerResults))
         return failure();
@@ -203,8 +203,21 @@ DiagnosedSilenceableFailure
 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                      TransformResults &transformResults,
                                      TransformState &state) {
+  // Map each loop handle to a LoopLikeOpInterface. Handles are visited in
+  // reverse operand order and only the first payload op of each handle is
+  // used; an empty handle is diagnosed instead of dereferencing `begin()`.
+  SmallVector<LoopLikeOpInterface> loops;
+  for (auto op : llvm::reverse(getLoops())) {
+    auto payloadOps = state.getPayloadOps(op);
+    if (llvm::empty(payloadOps))
+      return emitDefiniteFailure("expected a payload op for each loop handle");
+    auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(*payloadOps.begin());
+    if (!loopLikeOp)
+      return emitDefiniteFailure("expected a loop-like payload op");
+    loops.push_back(loopLikeOp);
+  }
   LogicalResult result = applyFuseConsumer(
-      rewriter, getOperation(), state.getPayloadOps(getTarget()),
+      rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
       getNumConsumerToFuse(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
@@ -213,6 +222,7 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
 void transform::TestFuseConsumerOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTargetMutable(), effects);
+  consumesHandle(getLoopsMutable(), effects);
   producesHandle(getOperation()->getOpResults(), effects);
   modifiesPayload(effects);
 }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 34b075a5c17f9..98f7145c99cb1 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -58,14 +58,18 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
     using the options provided as attributes.
   }];
 
-  let arguments =
-    (ins TransformHandleTypeInterface:$target,
-        DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
+  let arguments = (ins
+      TransformHandleTypeInterface:$target,
+      // Handles to the loops surrounding `target`; `apply` resolves each to a
+      // LoopLikeOpInterface payload op.
+      Variadic<TransformHandleTypeInterface>:$loops,
+      DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
   let results = (outs TransformHandleTypeInterface:$consumer,
                       TransformHandleTypeInterface:$fused_consumer);
 
   let assemblyFormat = [{
-    $target (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? 
+    $target `in` `(` $loops `)`
+    (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
     attr-dict `:` functional-type(operands, results)
   }];
 }



More information about the Mlir-commits mailing list