[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 8 01:49:36 PDT 2024


================
@@ -1289,28 +1352,108 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
   return getConsumerFromUses(resultingValue, containingOp->getBlock());
 }
 
-/// This utility currently checks whether the loop either :-
-/// 1. Yields exactly one result.
-/// 2. Has consumer op as its first user and other users to be in the same
-/// containing block as that of consumer op's. Currently we clone the loop op
-/// right before the consumer op in order to maintain a valid def-use chain.
-/// This utility thus helps ensuring that no invalid IR is formed due to the
-/// same.
+/// This utility currently checks whether the first userOp of the loop is NOT
+/// before the last defineOp of the consumer. Currently we clone the loop op
+/// right before a certain op in order to maintain a valid def-use chain. This
+/// utility thus helps ensure that no invalid IR is formed as a result. E.g.
+///
+/// ```
+/// %0 = scf.for() {
+///
+/// }
+/// ...
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ...
+/// %3 = consumerOp(%2)
+/// ```
+///
+/// If the `firstUserOfLoop` is before `lastDefOfConsumer`, then it would be
+/// invalid to clone the loop op right before the `firstUserOfLoop`:
+///
+/// ```
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ```
+///
+/// To address this issue, this utility checks that there is no user of
+/// `firstUserOfLoop` before `lastDefOfConsumer`. If there is none, it moves
+/// `firstUserOfLoop` after `lastDefOfConsumer`. The IR then becomes valid, as
+/// follows:
+///
+/// ```
+/// %2 = lastDefOfConsumer
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ```
+///
+/// @param loopOp: loop operation
+/// @param consumerOp: consumer operation
+/// @param insertPointBefore: the operation right before which loopOp is cloned
 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
-                                            Operation *consumerOp) {
-  // Check if the loop op yields one result.
-  if (loopOp->getNumResults() == 1)
-    return success();
-  // Check if the consumerOp is the first user of the loopOp and if other users
-  // are in the same containing block as that of consumer op's.
+                                            Operation *consumerOp,
+                                            Operation **insertPointBefore) {
   Block *parentBlock = consumerOp->getBlock();
+  // loopOp and consumerOp should stay in the same block.
+  if (loopOp->getBlock() != parentBlock)
+    return failure();
+
+  Operation *firstUserOfLoop = consumerOp, *lastDefOfConsumer = loopOp;
+  // Find the first user of loopOp
----------------
Yun-Fly wrote:

@Abhishek-Varma In some cases where the consumer is not the first user of the loop, we still have a chance to clone the loop right before the first user of the loop rather than right before the consumer.

The only concern is that we need to ensure that the FIRST userOp of the loop is NOT before the LAST defineOp of the consumer. E.g.

```
%0 = scf.for() {
  %1:2 = scf.for() {
      %t0 = ...
      %t1 = tiled_producer
      yield %t0, %t1
  }
  insert_slice %1#0
  ...
  %2 = extract_slice
  tiled_consumer ins(%1#1, %2)
  ...
}
```

If we want to iteratively fuse `tiled_consumer` into the `%1` loop, we cannot clone the loop right before `tiled_consumer`, because `insert_slice %1#0` is actually the first user of the loop. It is also not valid to simply clone the loop right before `insert_slice %1#0`, because that op comes before the last def of the operands of `tiled_consumer`, namely `%2 = extract_slice`. Thus, we can first check that there is no user of `firstUserOfLoop` before `lastDefOfConsumer`. If there is none, we move `firstUserOfLoop` after `lastDefOfConsumer`. The final IR looks like this:
```
%0 = scf.for() {
  ...
  %2 = extract_slice
  %1:3 = scf.for() {
    %t0 = ...
    %t1 = tiled_producer
    %t2 = tiled_consumer ins(%t1, %2)
    ...
    yield %t0, %t1, %t2
  }
  insert_slice %1#0
  ...
}
```
We discussed this before in a previous [thread](https://github.com/llvm/llvm-project/pull/88712#issuecomment-2104441680). Please let me know your thoughts on this enhancement.
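
To make the proposed check concrete, here is a minimal, self-contained sketch of the reordering logic on a toy model. It does not use the MLIR API: `ToyOp`, `ToyBlock`, `usesValue`, `indexOf`, and `findInsertPoint` are hypothetical names introduced only for illustration, each toy op has exactly one named result, and the sketch repeats the move until the first user of the loop lies after the last def of the consumer, which is an assumption beyond what the comment above spells out.

```cpp
#include <algorithm>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for an operation: a single named result plus operand names.
// This is NOT mlir::Operation; it only models block order and def-use edges.
struct ToyOp {
  std::string result;
  std::vector<std::string> operands;
};
using ToyBlock = std::vector<ToyOp>;

// True if `op` has `value` among its operands.
static bool usesValue(const ToyOp &op, const std::string &value) {
  return std::find(op.operands.begin(), op.operands.end(), value) !=
         op.operands.end();
}

// Position in the block of the op that defines `result`, if any.
static std::optional<std::size_t> indexOf(const ToyBlock &block,
                                          const std::string &result) {
  for (std::size_t i = 0; i < block.size(); ++i)
    if (block[i].result == result)
      return i;
  return std::nullopt;
}

// Returns the result name of the op right before which the loop could be
// cloned, reordering `block` if needed, or nullopt if no valid point exists.
std::optional<std::string> findInsertPoint(ToyBlock &block,
                                           const std::string &loop,
                                           const std::string &consumer) {
  for (;;) {
    auto loopIdx = indexOf(block, loop);
    auto consumerIdx = indexOf(block, consumer);
    // Both ops must be in the block, loop first, and the consumer must use it.
    if (!loopIdx || !consumerIdx || *consumerIdx <= *loopIdx ||
        !usesValue(block[*consumerIdx], loop))
      return std::nullopt;

    // First user of the loop result (one exists: the consumer itself).
    std::size_t firstUser = *loopIdx + 1;
    while (!usesValue(block[firstUser], loop))
      ++firstUser;

    // Last op between loop and consumer that defines a consumer operand.
    // The loop index doubles as a "no such op" sentinel.
    std::size_t lastDef = *loopIdx;
    for (std::size_t i = *loopIdx + 1; i < *consumerIdx; ++i)
      if (usesValue(block[*consumerIdx], block[i].result))
        lastDef = i;

    // Already valid: the first user comes after the last def.
    if (firstUser > lastDef)
      return block[firstUser].result;

    // The last def itself depends on the loop: nothing can be reordered.
    if (firstUser == lastDef)
      return std::nullopt;
    // A user of firstUser at or before lastDef blocks the move.
    for (std::size_t j = firstUser + 1; j <= lastDef; ++j)
      if (usesValue(block[j], block[firstUser].result))
        return std::nullopt;

    // Move firstUser to just after lastDef, then re-check.
    std::rotate(block.begin() + firstUser, block.begin() + firstUser + 1,
                block.begin() + lastDef + 1);
  }
}
```

Modelling the example above as four toy ops in the order loop, `insert_slice`, `extract_slice`, `tiled_consumer`, the sketch moves the `insert_slice` stand-in after the `extract_slice` stand-in and reports it as the op to clone the loop in front of. If the `extract_slice` stand-in instead depended on the `insert_slice` stand-in, the sketch would report failure rather than reorder.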

https://github.com/llvm/llvm-project/pull/94190