[Mlir-commits] [mlir] [mlir][scf] Extend consumer fusion to multiple tilable users (PR #111955)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 28 20:30:49 PDT 2024


https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/111955

>From d43ff6e4a4017d6e287029c3192695d468c77c74 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Thu, 10 Oct 2024 22:51:49 -0700
Subject: [PATCH 1/4] extend consumer fusion to multiple tilable users

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 175 ++++++++++++++----
 .../tile-and-fuse-consumer.mlir               |  62 +++++++
 .../TestTilingInterfaceTransformOps.cpp       |  29 +--
 .../TestTilingInterfaceTransformOps.td        |   6 +-
 4 files changed, 219 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b314540..a758db6c68cf81 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1585,26 +1585,27 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
 /// failure otherwise.
 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
-  // Check that the value has exactly one use which isn't a scf.yield or a
-  // tensor.parallel_insert_slice op.
   OpOperand *operand = nullptr;
   for (OpOperand &opOperand : val.getUses()) {
     Operation *consumerOp = opOperand.getOwner();
-    if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
+    // Step 1. Check if the user is tilable.
+    if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
+      // TODO: We have to init result of consumer before scf.for, use
+      // DestinationStyleOpInterface to get result shape from init for now. Add
+      // support for other op such as op has InferTypeOpInterface.
       continue;
-    if (operand)
-      return failure();
-    // TODO: We have to init result of consumer before scf.for, use
-    //       DestinationStyleOpInterface to get result shape from init for now.
-    //       Add support for other op such as op has InferTypeOpInterface.
-    if (!isa<TilingInterface>(consumerOp) ||
-        !isa<DestinationStyleOpInterface>(consumerOp))
-      return failure();
-    if (containingOpBlock != consumerOp->getBlock())
-      return failure();
-    operand = &opOperand;
+    } else {
+      // Step 2. Check if user stay in the same block.
+      if (containingOpBlock != consumerOp->getBlock())
+        continue;
+      // Step 3. Check if user has succeeding user. Otherwise, it usually
+      // represents already tiled.
+      if (consumerOp->use_empty())
+        continue;
+      operand = &opOperand;
+      break;
+    }
   }
-
   if (operand)
     return operand;
   return failure();
@@ -1699,28 +1700,123 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
   return getConsumerFromUses(resultingValue, containingOp->getBlock());
 }
 
-/// This utility currently checks whether the loop either :-
-/// 1. Yields exactly one result.
-/// 2. Has consumer op as its first user and other users to be in the same
-/// containing block as that of consumer op's. Currently we clone the loop op
-/// right before the consumer op in order to maintain a valid def-use chain.
-/// This utility thus helps ensuring that no invalid IR is formed due to the
-/// same.
+/// This utility currently checks whether the first userOp of loop is NOT before
+/// the last defineOp of consumer. Currently we need to move the loop op right
+/// before a certain op in order to maintain a valid def-use chain. This utility
+/// thus helps ensuring that no invalid IR is formed. E.g.
+///
+/// ```
+/// %0 = scf.for() {
+///   ...
+/// }
+/// ...
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ...
+/// %3 = consumerOp(%2)
+/// ```
+///
+/// If the `firstUserOfLoop`is before `lastDefOfConsumer`, then it would be
+/// invalid to move the loop op right before the `firstUserOfLoop`:
+///
+/// ```
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ```
+///
+/// To address this issue, this utility would double-check there is no user of
+/// `firstUserOfLoop` before `lastDefOfConsumer`. If so, move `firstUserOfLoop`
+/// after `lastDefOfConsumer`. Then, it turns out valid as follow:
+///
+/// ```
+/// %2 = lastDefOfConsumer
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ```
+///
+/// Besides, `consumerOp` should not be the user of `firstUserOfLoop`.
+///
+/// @param loopOp: loop operation
+/// @param consumerOp: consumer operation
+/// @param toMoveLoopOpBefore: the operation we move the looOp right before
 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
-                                            Operation *consumerOp) {
-  // Check if the loop op yields one result.
-  if (loopOp->getNumResults() == 1)
-    return success();
-  // Check if the consumerOp is the first user of the loopOp and if other users
-  // are in the same containing block as that of consumer op's.
+                                            Operation *consumerOp,
+                                            Operation **toMoveLoopOpBefore) {
   Block *parentBlock = consumerOp->getBlock();
-  for (Operation *userOp : loopOp->getUsers()) {
-    if (userOp == consumerOp)
-      continue;
-    if (parentBlock != userOp->getBlock() ||
-        !consumerOp->isBeforeInBlock(userOp))
-      return failure();
-  }
+  // loopOp and consumerOp should stay in the same block.
+  if (loopOp->getBlock() != parentBlock)
+    return failure();
+
+  *toMoveLoopOpBefore = nullptr;
+  do {
+    Operation *firstUserOfLoop = consumerOp, *lastDefOfConsumer = loopOp;
+    // Find the first user of loopOp
+    for (Operation *userOp : loopOp->getUsers()) {
+      if (userOp == consumerOp)
+        continue;
+      // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
+      // block with any other types of operation. Thus, just redirecting to its
+      // parent `InParallelOp`.
+      if (isa<tensor::ParallelInsertSliceOp>(userOp))
+        userOp = userOp->getParentOfType<scf::InParallelOp>();
+
+      if (parentBlock != userOp->getBlock())
+        return failure();
+
+      if (userOp->isBeforeInBlock(firstUserOfLoop))
+        firstUserOfLoop = userOp;
+    }
+
+    // Find the last define of consumer
+    for (Value operand : consumerOp->getOperands()) {
+      // If the operand is `BlockArgument`, auto skip.
+      if (isa<BlockArgument>(operand))
+        continue;
+      auto defineOp = operand.getDefiningOp();
+      if (!defineOp)
+        return failure();
+      // If defineOp is not in the same block with loopOp, it must dominate the
+      // loopOp as well. I.e.
+      // ```
+      //  %a = ...
+      //  {
+      //     %looOp = scf.for
+      //     %b = consumerOp ins(%loopOp, %a)
+      //   }
+      // ```
+      if (defineOp == loopOp || parentBlock != defineOp->getBlock())
+        continue;
+      if (lastDefOfConsumer->isBeforeInBlock(defineOp))
+        lastDefOfConsumer = defineOp;
+    }
+    if (firstUserOfLoop->isBeforeInBlock(lastDefOfConsumer)) {
+      // Try to move if possible
+      if (llvm::all_of(firstUserOfLoop->getUsers(),
+                       [&lastDefOfConsumer, &parentBlock](Operation *userOp) {
+                         return userOp->getBlock() == parentBlock &&
+                                lastDefOfConsumer->isBeforeInBlock(userOp);
+                       })) {
+        // Safely moving
+        firstUserOfLoop->moveAfter(lastDefOfConsumer);
+      } else {
+        return failure();
+      }
+    } else {
+      // Check consumerOp is not the user of firstUserOfLoop
+      if (firstUserOfLoop == lastDefOfConsumer)
+        return failure();
+      // Set InsertPoint
+      *toMoveLoopOpBefore = firstUserOfLoop;
+    }
+  } while (!(*toMoveLoopOpBefore));
+
   return success();
 }
 
@@ -1787,7 +1883,10 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
 
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
+  // Find suitable insertPointOp to move the whole loop structure later.
+  Operation *toMoveLoopOpBefore = nullptr;
+  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp,
+                                    &toMoveLoopOpBefore))) {
     return rewriter.notifyMatchFailure(
         outerMostLoop,
         "containing loop op should either yield just one value or "
@@ -1812,9 +1911,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   Location loc = outerMostLoop->getLoc();
 
-  // 3. Move the whole loop structure right before consumer Op, the dominance
+  // 3. Move the whole loop structure right before insertPoint, the dominance
   // should be already ensured by `checkAssumptionForLoop`.
-  rewriter.moveOpBefore(outerMostLoop, consumerOp);
+  rewriter.moveOpBefore(outerMostLoop, toMoveLoopOpBefore);
 
   // 4. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index f5f703d95e2d5b..af836d18e8c028 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -508,3 +508,65 @@ module {
 //      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
 //      CHECK:   }
 //      CHECK:   return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
+
+// -----
+
+module {
+  func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+        %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+        %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+        scf.yield %insert_slice : tensor<256x256xf32>
+    }
+    %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_add_multiple_tilable_consumers(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
+// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) 
+// CHECK-SAME:   {
+//      CHECK:          %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME:                ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
+// CHECK-SAME:                outs(%[[ADD_OUT_SLICE]] :
+//      CHECK:          %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME:                ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME:                outs(%[[EXP_OUT_SLICE]] :
+//      CHECK:          %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_MUL_OUT:.*]] = linalg.mul
+// CHECK-SAME:                ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
+// CHECK-SAME:                outs(%[[MUL_OUT_SLICE]] :
+//      CHECK:          %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
+//      CHECK:   }
+//      CHECK:   return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index b6da47977cb4cf..5e903e378daf82 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -171,24 +171,27 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 template <typename Range>
 static LogicalResult
 applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
-                  Range &&payloadOps, TransformResults &transformResults) {
+                  Range &&payloadOps, uint32_t numConsumerToFuse,
+                  TransformResults &transformResults) {
   SmallVector<Operation *> originalConsumerOps;
   SmallVector<Operation *> fusedConsumerOps;
 
   for (Operation *target : payloadOps) {
     rewriter.setInsertionPoint(target);
 
-    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
-        scf::tileAndFuseConsumerOfSlice(rewriter, target);
+    while (numConsumerToFuse--) {
+      FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+          scf::tileAndFuseConsumerOfSlice(rewriter, target);
 
-    if (failed(fuseConsumerResults))
-      return failure();
+      if (failed(fuseConsumerResults))
+        return failure();
 
-    // Report back the relevant handles to the transform op.
-    originalConsumerOps.push_back(
-        fuseConsumerResults->origConsumerOperand->getOwner());
-    fusedConsumerOps.push_back(
-        fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+      // Report back the relevant handles to the transform op.
+      originalConsumerOps.push_back(
+          fuseConsumerResults->origConsumerOperand->getOwner());
+      fusedConsumerOps.push_back(
+          fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+    }
   }
 
   transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
@@ -200,9 +203,9 @@ DiagnosedSilenceableFailure
 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                      TransformResults &transformResults,
                                      TransformState &state) {
-  LogicalResult result =
-      applyFuseConsumer(rewriter, getOperation(),
-                        state.getPayloadOps(getTarget()), transformResults);
+  LogicalResult result = applyFuseConsumer(
+      rewriter, getOperation(), state.getPayloadOps(getTarget()),
+      getNumConsumerToFuse(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index d55d746bd6aa90..34b075a5c17f9e 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -59,12 +59,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
   }];
 
   let arguments =
-    (ins TransformHandleTypeInterface:$target);
+    (ins TransformHandleTypeInterface:$target,
+        DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
   let results = (outs TransformHandleTypeInterface:$consumer,
                       TransformHandleTypeInterface:$fused_consumer);
 
   let assemblyFormat = [{
-    $target attr-dict `:` functional-type(operands, results)
+    $target (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? 
+    attr-dict `:` functional-type(operands, results)
   }];
 }
 

>From 02182dc8319b94700d00222c8e94fba814ba5e86 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Tue, 15 Oct 2024 01:36:01 -0700
Subject: [PATCH 2/4] fix comment by backward slice

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 154 ++++++++++--------
 1 file changed, 82 insertions(+), 72 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a758db6c68cf81..b5b2faae3736d2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -1702,7 +1704,7 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
 
 /// This utility currently checks whether the first userOp of loop is NOT before
 /// the last defineOp of consumer. Currently we need to move the loop op right
-/// before a certain op in order to maintain a valid def-use chain. This utility
+/// before a certain op in order to maintain a valid use-def chain. This utility
 /// thus helps ensuring that no invalid IR is formed. E.g.
 ///
 /// ```
@@ -1718,10 +1720,12 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
 /// ```
 ///
 /// If the `firstUserOfLoop`is before `lastDefOfConsumer`, then it would be
-/// invalid to move the loop op right before the `firstUserOfLoop`:
+/// invalid to move the loop op right before the `firstUserOfLoop`, a.k.a.
+/// use-def chain violation:
 ///
 /// ```
 /// %0:2 = scf.for() {
+///    // use before define error
 ///    %3 = tiledConsumerOp(%2)
 /// }
 /// %1 = firstUserOfLoop(%0)
@@ -1729,9 +1733,9 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
 /// %2 = lastDefOfConsumer
 /// ```
 ///
-/// To address this issue, this utility would double-check there is no user of
-/// `firstUserOfLoop` before `lastDefOfConsumer`. If so, move `firstUserOfLoop`
-/// after `lastDefOfConsumer`. Then, it turns out valid as follow:
+/// To address this issue, this utility would try to move `lastDefOfConsumer`
+/// before `firstUserOfLoop` under intrusive mode. Then, it turns out valid as
+/// follow:
 ///
 /// ```
 /// %2 = lastDefOfConsumer
@@ -1741,81 +1745,87 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
 /// %1 = firstUserOfLoop(%0)
 /// ```
 ///
-/// Besides, `consumerOp` should not be the user of `firstUserOfLoop`.
-///
 /// @param loopOp: loop operation
 /// @param consumerOp: consumer operation
-/// @param toMoveLoopOpBefore: the operation we move the looOp right before
-static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+/// @param firstUserOfLoop: the first user of loopOp, which op we move the looOp
+/// right before
+/// @param intrusive: if true, it allows to move computed slice w.r.t defineOp
+/// of operands of consumerOp. The default value is True. If explicit memory
+/// barrier is required, please turn it off.
+static LogicalResult checkAssumptionForLoop(RewriterBase &rewriter,
+                                            Operation *loopOp,
                                             Operation *consumerOp,
-                                            Operation **toMoveLoopOpBefore) {
+                                            Operation **firstUserOfLoop,
+                                            bool intrusive = true) {
   Block *parentBlock = consumerOp->getBlock();
-  // loopOp and consumerOp should stay in the same block.
+  // 1. Check if loopOp and consumerOp stay in the same block.
   if (loopOp->getBlock() != parentBlock)
     return failure();
 
-  *toMoveLoopOpBefore = nullptr;
-  do {
-    Operation *firstUserOfLoop = consumerOp, *lastDefOfConsumer = loopOp;
-    // Find the first user of loopOp
-    for (Operation *userOp : loopOp->getUsers()) {
-      if (userOp == consumerOp)
-        continue;
-      // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
-      // block with any other types of operation. Thus, just redirecting to its
-      // parent `InParallelOp`.
-      if (isa<tensor::ParallelInsertSliceOp>(userOp))
-        userOp = userOp->getParentOfType<scf::InParallelOp>();
+  *firstUserOfLoop = consumerOp;
+  // 2. Find the first user of loopOp.
+  for (Operation *userOp : loopOp->getUsers()) {
+    if (userOp == consumerOp)
+      continue;
+    // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
+    // block with any other types of operation. Thus, just redirecting to its
+    // parent `InParallelOp`.
+    if (isa<tensor::ParallelInsertSliceOp>(userOp))
+      userOp = userOp->getParentOfType<scf::InParallelOp>();
 
-      if (parentBlock != userOp->getBlock())
-        return failure();
+    if (parentBlock != userOp->getBlock())
+      return failure();
 
-      if (userOp->isBeforeInBlock(firstUserOfLoop))
-        firstUserOfLoop = userOp;
-    }
+    if (userOp->isBeforeInBlock(*firstUserOfLoop))
+      *firstUserOfLoop = userOp;
+  }
 
-    // Find the last define of consumer
-    for (Value operand : consumerOp->getOperands()) {
-      // If the operand is `BlockArgument`, auto skip.
-      if (isa<BlockArgument>(operand))
-        continue;
-      auto defineOp = operand.getDefiningOp();
-      if (!defineOp)
-        return failure();
-      // If defineOp is not in the same block with loopOp, it must dominate the
-      // loopOp as well. I.e.
-      // ```
-      //  %a = ...
-      //  {
-      //     %looOp = scf.for
-      //     %b = consumerOp ins(%loopOp, %a)
-      //   }
-      // ```
-      if (defineOp == loopOp || parentBlock != defineOp->getBlock())
-        continue;
-      if (lastDefOfConsumer->isBeforeInBlock(defineOp))
-        lastDefOfConsumer = defineOp;
-    }
-    if (firstUserOfLoop->isBeforeInBlock(lastDefOfConsumer)) {
-      // Try to move if possible
-      if (llvm::all_of(firstUserOfLoop->getUsers(),
-                       [&lastDefOfConsumer, &parentBlock](Operation *userOp) {
-                         return userOp->getBlock() == parentBlock &&
-                                lastDefOfConsumer->isBeforeInBlock(userOp);
-                       })) {
-        // Safely moving
-        firstUserOfLoop->moveAfter(lastDefOfConsumer);
-      } else {
-        return failure();
+  // 3. Find backward slice of defOfConsumer.
+  BackwardSliceOptions options;
+  DominanceInfo dominanceInfo;
+  options.inclusive = true;
+  options.omitBlockArguments = true;
+
+  for (auto operand : consumerOp->getOperands()) {
+    llvm::SetVector<Operation *> slice;
+    bool includeLoopOp = false;
+    options.filter = [&](Operation *op) {
+      if (op == loopOp) {
+        includeLoopOp = true;
+        return false;
+      }
+      // Cut off the slice to not include any operation that already dominates
+      // firstUserOfLoop.
+      return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
+    };
+    getBackwardSlice(operand, &slice, options);
+    if (!slice.empty()) {
+      if (includeLoopOp) {
+        // If consumerOp has one producer, which is also the user of loopOp.
+        // E.g.
+        // ```
+        //  %0 = %loopOp
+        //  %1 = consumerOp1 ins(%0)
+        //  %2 = consumerOp2 ins(%0, %1)
+        // ```
+        // We can not fuse consumerOp2 into loopOp due to UD chain, unless
+        // consumerOp1 has already been fused into loopOp before.
+        return rewriter.notifyMatchFailure(
+            consumerOp, "could not fuse consumer due to inevitable use-def "
+                        "chain violation");
+      }
+      if (!intrusive) {
+        // Please turn on intrusive mode, otherwise just bail out.
+        return rewriter.notifyMatchFailure(consumerOp,
+                                           "intrusive mode is not allowed");
+      }
+      mlir::topologicalSort(slice);
+      // 4. Move all computed slice before firstUserOfLoop.
+      for (auto op : slice) {
+        rewriter.moveOpBefore(op, *firstUserOfLoop);
       }
-    } else {
-      // Check consumerOp is not the user of firstUserOfLoop
-      if (firstUserOfLoop == lastDefOfConsumer)
-        return failure();
-      // Set InsertPoint
-      *toMoveLoopOpBefore = firstUserOfLoop;
     }
-  } while (!(*toMoveLoopOpBefore));
+  }
 
   return success();
 }
@@ -1884,9 +1894,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
 
   // Find suitable insertPointOp to move the whole loop structure later.
-  Operation *toMoveLoopOpBefore = nullptr;
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp,
-                                    &toMoveLoopOpBefore))) {
+  Operation *firstUserOfLoop = nullptr;
+  if (failed(checkAssumptionForLoop(rewriter, outerMostLoop, consumerOp,
+                                    &firstUserOfLoop))) {
     return rewriter.notifyMatchFailure(
         outerMostLoop,
         "containing loop op should either yield just one value or "
@@ -1913,7 +1923,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   // 3. Move the whole loop structure right before insertPoint, the dominance
   // should be already ensured by `checkAssumptionForLoop`.
-  rewriter.moveOpBefore(outerMostLoop, toMoveLoopOpBefore);
+  rewriter.moveOpBefore(outerMostLoop, firstUserOfLoop);
 
   // 4. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the

>From 1065a020289eae93f3b8e87202aa82244da9b81f Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Tue, 15 Oct 2024 02:19:46 -0700
Subject: [PATCH 3/4] prepare for next consumer

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 38 ++++++++++++++-----
 1 file changed, 29 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index b5b2faae3736d2..92364ce19e39ad 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1800,16 +1800,16 @@ static LogicalResult checkAssumptionForLoop(RewriterBase &rewriter,
     };
     getBackwardSlice(operand, &slice, options);
     if (!slice.empty()) {
+      // If consumerOp has one producer, which is also the user of loopOp.
+      // E.g.
+      // ```
+      //  %0 = %loopOp
+      //  %1 = consumerOp1 ins(%0)
+      //  %2 = consumerOp2 ins(%0, %1)
+      // ```
+      // We can not fuse consumerOp2 into loopOp due to UD chain, unless
+      // consumerOp1 has already been fused into loopOp before.
       if (includeLoopOp) {
-        // If consumerOp has one producer, which is also the user of loopOp.
-        // E.g.
-        // ```
-        //  %0 = %loopOp
-        //  %1 = consumerOp1 ins(%0)
-        //  %2 = consumerOp2 ins(%0, %1)
-        // ```
-        // We can not fuse consumerOp2 into loopOp due to UD chain, unless
-        // consumerOp1 has already been fused into loopOp before.
         return rewriter.notifyMatchFailure(
             consumerOp, "could not fuse consumer due to inevitable use-def "
                         "chain violation");
@@ -1843,6 +1843,24 @@ static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
   }
 }
 
+/// A utility to move the given operand to the end of use list.
+static void moveOperandToEndOfUseList(OpOperand *operand) {
+  Value::use_range uses = operand->get().getUses();
+  size_t numberUses = std::distance(uses.begin(), uses.end());
+  if (numberUses == 1)
+    return;
+  auto iter = llvm::find(uses, *operand);
+  if (iter == uses.end())
+    return;
+  unsigned index = std::distance(uses.begin(), iter);
+  SmallVector<unsigned> indices =
+      llvm::to_vector(llvm::seq<unsigned>(numberUses));
+  indices.push_back(indices[index]);
+  indices.erase(indices.begin() + index);
+  operand->get().shuffleUseList(indices);
+  return;
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1897,6 +1915,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   Operation *firstUserOfLoop = nullptr;
   if (failed(checkAssumptionForLoop(rewriter, outerMostLoop, consumerOp,
                                     &firstUserOfLoop))) {
+    // Prepare for next consumer.
+    moveOperandToEndOfUseList(consumerOpOperand);
     return rewriter.notifyMatchFailure(
         outerMostLoop,
         "containing loop op should either yield just one value or "

>From cf287cb94917f5ad8cfa881b38e17e7cb028c185 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Tue, 29 Oct 2024 11:13:28 +0800
Subject: [PATCH 4/4] split move and check method

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 360 +++++++++---------
 1 file changed, 173 insertions(+), 187 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 92364ce19e39ad..02e58141bdc303 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1582,12 +1582,131 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
   return success();
 }
 
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
-                                                  Block *containingOpBlock) {
-  OpOperand *operand = nullptr;
+/// A utility to get the first user of the given loopOp. Fails if `loopOp` is
+/// not a loop, has no users, or any of its users lives in a different block.
+static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
+  if (!isa<LoopLikeOpInterface>(loopOp))
+    return failure();
+  Operation *firstUserOfLoop = nullptr;
+  for (Operation *userOp : loopOp->getUsers()) {
+    // `ParallelInsertSlice` lives in the region of its parent `InParallelOp`,
+    // not in the loop's block; redirect to that `InParallelOp`. E.g.
+    //
+    // ```
+    // %1 = scf.for {
+    //   ...
+    // }
+    // %2 = consumerOp ins(%1, ...)
+    // scf.forall.in_parallel {
+    //    tensor.parallel_insert_slice %1
+    // }
+    // ```
+    // where `InParallelOp` (not the slice op) is in `consumerOp`'s block.
+    if (isa<tensor::ParallelInsertSliceOp>(userOp))
+      userOp = userOp->getParentOfType<scf::InParallelOp>();
+
+    if (loopOp->getBlock() != userOp->getBlock())
+      return failure();
+
+    if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
+      firstUserOfLoop = userOp;
+  }
+  if (!firstUserOfLoop)
+    return failure(); // No user, hence no insertion point: never return null.
+  return firstUserOfLoop;
+}
+
+/// This utility checks whether the first user of the loop is NOT before the
+/// last definition of any consumer operand. Since the whole loop structure is
+/// later moved right before `firstUserOfLoop`, this helps ensure that no
+/// invalid IR is formed, i.e. no op in the backward slice of consumerOp is
+/// dominated by `firstUserOfLoop`. For example:
+///
+/// ```
+/// %0 = scf.for() {
+///   ...
+/// }
+/// ...
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumerOperand
+/// ...
+/// %3 = consumerOp(%2)
+/// ```
+///
+/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
+/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, i.e. a
+/// use-def chain violation:
+///
+/// ```
+/// %0:2 = scf.for() {
+///    // use before define error
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumerOperand
+/// ```
+///
+/// @param loopOp: loop operation
+/// @param consumerOp: consumer operation
+/// @param reorderOperations: whether a non-empty slice, which the caller must
+/// move before `firstUserOfLoop`, is acceptable.
+/// @return: backward slice of consumerOp, excluding ops that already dominate
+/// `firstUserOfLoop`; failure if that slice depends on loopOp itself.
+static FailureOr<llvm::SetVector<Operation *>>
+checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
+                       bool reorderOperations) {
+  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
+  if (failed(firstUserOfLoop))
+    return failure();
+
+  BackwardSliceOptions options;
+  DominanceInfo dominanceInfo;
+  options.inclusive = true;
+  options.omitBlockArguments = true;
+  bool includeLoopOp = false;
+  options.filter = [&](Operation *op) {
+    if (op == loopOp) {
+      includeLoopOp = true;
+      return false;
+    }
+    // Cut off any operation that already dominates firstUserOfLoop.
+    return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
+  };
+  llvm::SetVector<Operation *> slice;
+  for (Value operand : consumerOp->getOperands()) {
+    // Results of loopOp add nothing to the slice but would set includeLoopOp.
+    if (operand.getDefiningOp() != loopOp)
+      getBackwardSlice(operand, &slice, options);
+  }
+
+  if (!slice.empty()) {
+    // Bail out if a producer of consumerOp also uses loopOp, e.g.
+    // ```
+    //  %0 = %loopOp
+    //  %1 = consumerOp1 ins(%0)
+    //  %2 = consumerOp2 ins(%0, %1)
+    // ```
+    // consumerOp2 can not be fused into loopOp due to the use-def chain,
+    // unless consumerOp1 has already been fused into loopOp before.
+    if (includeLoopOp || !reorderOperations)
+      return failure();
+  }
+
+  return slice;
+}
+
+/// Fetches the first valid use of result `resultNumber` of `loopOp` (owner
+/// implements `TilingInterface` and `DestinationStyleOpInterface`), moving the
+/// owner's backward slice before the loop's first user if needed; else fails.
+static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
+                                                      Operation *loopOp,
+                                                      unsigned resultNumber) {
+  if (!isa<LoopLikeOpInterface>(loopOp))
+    return failure();
+  Value val = loopOp->getResult(resultNumber);
+  Block *loopBlock = loopOp->getBlock();
   for (OpOperand &opOperand : val.getUses()) {
     Operation *consumerOp = opOperand.getOwner();
     // Step 1. Check if the user is tilable.
@@ -1596,20 +1715,30 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
       // DestinationStyleOpInterface to get result shape from init for now. Add
       // support for other op such as op has InferTypeOpInterface.
       continue;
-    } else {
-      // Step 2. Check if user stay in the same block.
-      if (containingOpBlock != consumerOp->getBlock())
-        continue;
-      // Step 3. Check if user has succeeding user. Otherwise, it usually
-      // represents already tiled.
-      if (consumerOp->use_empty())
-        continue;
-      operand = &opOperand;
-      break;
     }
+    // Step 2. Check that the user is in the same block as the loop.
+    if (loopBlock != consumerOp->getBlock())
+      continue;
+    // Step 3. Check that the user itself has uses; an unused consumer usually
+    // means it has already been tiled.
+    if (consumerOp->use_empty())
+      continue;
+    // Step 4. Check assumption for loop with `reorderOperations` enabled.
+    FailureOr<llvm::SetVector<Operation *>> slice =
+        checkAssumptionForLoop(loopOp, consumerOp, true);
+    if (failed(slice))
+      continue;
+    // Step 5. If the backward slice is not empty, move it before firstUserOfLoop.
+    if (!slice->empty()) {
+      mlir::topologicalSort(*slice);
+      FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
+      assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
+      for (auto op : *slice) {
+        rewriter.moveOpBefore(op, *firstUserOfLoop);
+      }
+    }
+    return &opOperand;
   }
-  if (operand)
-    return operand;
   return failure();
 }
 
@@ -1662,7 +1791,8 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
 /// 1.  tensor.insert_slice has scf.yield as its only user.
 /// 2.  scf.for's corresponding result has only one use.
 static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter,
+                            tensor::InsertSliceOp candidateSliceOp) {
   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
     return failure();
   Value sliceResult = candidateSliceOp.getResult();
@@ -1675,15 +1805,15 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
   if (!forOp)
     return failure();
   scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
-  Value resultingValue = topLevelForOp->getResult(resultNumber);
 
-  return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
+  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
 }
 
 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
 /// by a tensor.parallel_insert_slice.
 static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter,
+                            tensor::ParallelInsertSliceOp candidateSliceOp) {
   // Step 1. Fetch the corresponding output
   Value sliceDest = candidateSliceOp.getDest();
   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
@@ -1696,171 +1826,27 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
   auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
   if (!forallOp)
     return failure();
-  Value resultingValue =
-      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  unsigned resultNumber =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
+          .getResultNumber();
 
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// This utility currently checks whether the first userOp of loop is NOT before
-/// the last defineOp of consumer. Currently we need to move the loop op right
-/// before a certain op in order to maintain a valid use-def chain. This utility
-/// thus helps ensuring that no invalid IR is formed. E.g.
-///
-/// ```
-/// %0 = scf.for() {
-///   ...
-/// }
-/// ...
-/// %1 = firstUserOfLoop(%0)
-/// ...
-/// %2 = lastDefOfConsumer
-/// ...
-/// %3 = consumerOp(%2)
-/// ```
-///
-/// If the `firstUserOfLoop`is before `lastDefOfConsumer`, then it would be
-/// invalid to move the loop op right before the `firstUserOfLoop`, a.k.a.
-/// use-def chain violation:
-///
-/// ```
-/// %0:2 = scf.for() {
-///    // use before define error
-///    %3 = tiledConsumerOp(%2)
-/// }
-/// %1 = firstUserOfLoop(%0)
-/// ...
-/// %2 = lastDefOfConsumer
-/// ```
-///
-/// To address this issue, this utility would try to move `lastDefOfConsumer`
-/// before `firstUserOfLoop` under intrusive mode. Then, it turns out valid as
-/// follow:
-///
-/// ```
-/// %2 = lastDefOfConsumer
-/// %0:2 = scf.for() {
-///    %3 = tiledConsumerOp(%2)
-/// }
-/// %1 = firstUserOfLoop(%0)
-/// ```
-///
-/// @param loopOp: loop operation
-/// @param consumerOp: consumer operation
-/// @param firstUserOfLoop: the first user of loopOp, which op we move the looOp
-/// right before
-/// @param intrusive: if true, it allows to move computed slice w.r.t defineOp
-/// of operands of consumerOp. The default value is True. If explicit memory
-/// barrier is required, please turn it off.
-static LogicalResult checkAssumptionForLoop(RewriterBase &rewriter,
-                                            Operation *loopOp,
-                                            Operation *consumerOp,
-                                            Operation **firstUserOfLoop,
-                                            bool intrusive = true) {
-  Block *parentBlock = consumerOp->getBlock();
-  // 1. Check if loopOp and consumerOp stay in the same block.
-  if (loopOp->getBlock() != parentBlock)
-    return failure();
-
-  *firstUserOfLoop = consumerOp;
-  // 2. Find the first user of loopOp.
-  for (Operation *userOp : loopOp->getUsers()) {
-    if (userOp == consumerOp)
-      continue;
-    // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
-    // block with any other types of operation. Thus, just redirecting to its
-    // parent `InParallelOp`.
-    if (isa<tensor::ParallelInsertSliceOp>(userOp))
-      userOp = userOp->getParentOfType<scf::InParallelOp>();
-
-    if (parentBlock != userOp->getBlock())
-      return failure();
-
-    if (userOp->isBeforeInBlock(*firstUserOfLoop))
-      *firstUserOfLoop = userOp;
-  }
-
-  // 3. Find backward slice of defOfConsumer.
-  BackwardSliceOptions options;
-  DominanceInfo dominanceInfo;
-  options.inclusive = true;
-  options.omitBlockArguments = true;
-
-  for (auto operand : consumerOp->getOperands()) {
-    llvm::SetVector<Operation *> slice;
-    bool includeLoopOp = false;
-    options.filter = [&](Operation *op) {
-      if (op == loopOp) {
-        includeLoopOp = true;
-        return false;
-      }
-      // Cut off the slice to not include any operation that already dominates
-      // firstUserOfLoop.
-      return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
-    };
-    getBackwardSlice(operand, &slice, options);
-    if (!slice.empty()) {
-      // If consumerOp has one producer, which is also the user of loopOp.
-      // E.g.
-      // ```
-      //  %0 = %loopOp
-      //  %1 = consumerOp1 ins(%0)
-      //  %2 = consumerOp2 ins(%0, %1)
-      // ```
-      // We can not fuse consumerOp2 into loopOp due to UD chain, unless
-      // consumerOp1 has already been fused into loopOp before.
-      if (includeLoopOp) {
-        return rewriter.notifyMatchFailure(
-            consumerOp, "could not fuse consumer due to inevitable use-def "
-                        "chain violation");
-      }
-      if (!intrusive) {
-        // Please turn on intrusive mode, otherwise just bail out.
-        return rewriter.notifyMatchFailure(consumerOp,
-                                           "intrusive mode is not allowed");
-      }
-      mlir::topologicalSort(slice);
-      // 4. Move all computed slice before firstUserOfLoop.
-      for (auto op : slice) {
-        rewriter.moveOpBefore(op, *firstUserOfLoop);
-      }
-    }
-  }
-
-  return success();
+  return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
 }
 
 /// A utility to fetch an untiled consumer of
 /// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(insertSlice);
+    return getUntiledConsumerFromSlice(rewriter, insertSlice);
   } else if (auto parallelInsertSlice =
                  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(parallelInsertSlice);
+    return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
   } else {
     return failure();
   }
 }
 
-/// A utility to move the given operand to the end of use list.
-static void moveOperandToEndOfUseList(OpOperand *operand) {
-  Value::use_range uses = operand->get().getUses();
-  size_t numberUses = std::distance(uses.begin(), uses.end());
-  if (numberUses == 1)
-    return;
-  auto iter = llvm::find(uses, *operand);
-  if (iter == uses.end())
-    return;
-  unsigned index = std::distance(uses.begin(), iter);
-  SmallVector<unsigned> indices =
-      llvm::to_vector(llvm::seq<unsigned>(numberUses));
-  indices.push_back(indices[index]);
-  indices.erase(indices.begin() + index);
-  operand->get().shuffleUseList(indices);
-  return;
-}
-
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1875,7 +1861,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   // 1. Get the consumer of scf.for for the result yielded by
   // tensor.insert_slice/parallel_insert_slice.
   FailureOr<OpOperand *> maybeConsumerOpOperand =
-      getUntiledConsumerFromSlice(candidateSliceOp);
+      getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
   if (failed(maybeConsumerOpOperand)) {
     return rewriter.notifyMatchFailure(candidateSliceOp,
                                        "could not fetch consumer to fuse");
@@ -1911,16 +1897,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
 
-  // Find suitable insertPointOp to move the whole loop structure later.
-  Operation *firstUserOfLoop = nullptr;
-  if (failed(checkAssumptionForLoop(rewriter, outerMostLoop, consumerOp,
-                                    &firstUserOfLoop))) {
-    // Prepare for next consumer.
-    moveOperandToEndOfUseList(consumerOpOperand);
+  // Check assumption for loop with `reorderOperations` disabled.
+  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
     return rewriter.notifyMatchFailure(
-        outerMostLoop,
-        "containing loop op should either yield just one value or "
-        "have the consumer op as its first user");
+        outerMostLoop, "the first user of loop should not dominate any define "
+                       "of consumer operand(s)");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
@@ -1941,9 +1922,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   Location loc = outerMostLoop->getLoc();
 
-  // 3. Move the whole loop structure right before insertPoint, the dominance
-  // should be already ensured by `checkAssumptionForLoop`.
-  rewriter.moveOpBefore(outerMostLoop, firstUserOfLoop);
+  // 3. Move the whole loop structure right before firstUserOfLoop, the
+  // dominance should be already ensured by `checkAssumptionForLoop`.
+  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
+  if (failed(firstUserOfLoop)) {
+    return rewriter.notifyMatchFailure(
+        outerMostLoop, "could not find the first user of outer most loop");
+  }
+  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
 
   // 4. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the



More information about the Mlir-commits mailing list