[Mlir-commits] [mlir] [mlir][scf] Extend consumer fusion to multiple tilable users (PR #111955)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 28 00:47:33 PDT 2024


================
@@ -1699,28 +1702,131 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
   return getConsumerFromUses(resultingValue, containingOp->getBlock());
 }
 
-/// This utility currently checks whether the loop either :-
-/// 1. Yields exactly one result.
-/// 2. Has consumer op as its first user and other users to be in the same
-/// containing block as that of consumer op's. Currently we clone the loop op
-/// right before the consumer op in order to maintain a valid def-use chain.
-/// This utility thus helps ensuring that no invalid IR is formed due to the
-/// same.
-static LogicalResult checkAssumptionForLoop(Operation *loopOp,
-                                            Operation *consumerOp) {
-  // Check if the loop op yields one result.
-  if (loopOp->getNumResults() == 1)
-    return success();
-  // Check if the consumerOp is the first user of the loopOp and if other users
-  // are in the same containing block as that of consumer op's.
+/// This utility currently checks whether the first userOp of loop is NOT before
+/// the last defineOp of consumer. Currently we need to move the loop op right
+/// before a certain op in order to maintain a valid use-def chain. This utility
+/// thus helps ensuring that no invalid IR is formed. E.g.
+///
+/// ```
+/// %0 = scf.for() {
+///   ...
+/// }
+/// ...
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ...
+/// %3 = consumerOp(%2)
+/// ```
+///
+/// If the `firstUserOfLoop` is before `lastDefOfConsumer`, then it would be
+/// invalid to move the loop op right before the `firstUserOfLoop`, a.k.a.
+/// use-def chain violation:
+///
+/// ```
+/// %0:2 = scf.for() {
+///    // use before define error
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ```
+///
+/// To address this issue, this utility would try to move `lastDefOfConsumer`
+/// before `firstUserOfLoop` under intrusive mode. Then, it turns out valid as
+/// follow:
+///
+/// ```
+/// %2 = lastDefOfConsumer
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ```
+///
+/// @param loopOp: loop operation
+/// @param consumerOp: consumer operation
+/// @param firstUserOfLoop: the first user of loopOp, which op we move the loopOp
+/// right before
+/// @param intrusive: if true, it allows to move computed slice w.r.t defineOp
+/// of operands of consumerOp. The default value is True. If explicit memory
+/// barrier is required, please turn it off.
+static LogicalResult checkAssumptionForLoop(RewriterBase &rewriter,
+                                            Operation *loopOp,
+                                            Operation *consumerOp,
+                                            Operation **firstUserOfLoop,
+                                            bool intrusive = true) {
   Block *parentBlock = consumerOp->getBlock();
+  // 1. Check if loopOp and consumerOp stay in the same block.
+  if (loopOp->getBlock() != parentBlock)
+    return failure();
+
+  *firstUserOfLoop = consumerOp;
+  // 2. Find the first user of loopOp.
   for (Operation *userOp : loopOp->getUsers()) {
     if (userOp == consumerOp)
       continue;
-    if (parentBlock != userOp->getBlock() ||
-        !consumerOp->isBeforeInBlock(userOp))
+    // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
+    // block with any other types of operation. Thus, just redirecting to its
+    // parent `InParallelOp`.
+    if (isa<tensor::ParallelInsertSliceOp>(userOp))
----------------
Yun-Fly wrote:

Let's take the following case as an example:
```
%1 = scf.for
%2 = op1 ins(%1, ...)
scf.forall.in_parallel {
      tensor.parallel_insert_slice %1
}
```
The users of `%1` (the result of the loop op) are `op1` and the `ParallelInsertSliceOp`. To determine which one is the first user of the loop, all users must be in the **same** block. The `ParallelInsertSliceOp`, however, lives inside its parent `InParallelOp` rather than in the block that holds `op1`, so for convenience we redirect it to that parent and compare the parent with `op1` instead. Otherwise, the `isBeforeInBlock` method would fail its assertion.
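
To illustrate the redirection, here is a minimal sketch that does not use MLIR at all. `ToyBlock`, `ToyOp`, `redirectToParent`, and `firstUser` are hypothetical stand-ins invented for this example; only the same-block precondition of `isBeforeInBlock` is modeled after the real method.

```cpp
#include <cassert>
#include <cstddef>

// Toy stand-ins for a block and an operation (NOT the real MLIR classes).
struct ToyBlock {};

struct ToyOp {
  ToyBlock *block;            // block that directly contains this op
  std::size_t order;          // position of the op inside `block`
  ToyOp *parentOp;            // op that owns `block`, nullptr at top level
  bool isParallelInsertSlice; // true for the op nested in the in_parallel op
};

// Models the precondition described above: both ops must share a block.
inline bool isBeforeInBlock(const ToyOp &a, const ToyOp &b) {
  assert(a.block == b.block && "ops must be in the same block");
  return a.order < b.order;
}

// A parallel-insert-slice user is replaced by its parent op, which sits in
// the same block as the other users of the loop result.
inline const ToyOp &redirectToParent(const ToyOp &user) {
  if (user.isParallelInsertSlice && user.parentOp)
    return *user.parentOp;
  return user;
}

// Returns whichever of the two (redirected) users comes first in the block.
inline const ToyOp &firstUser(const ToyOp &a, const ToyOp &b) {
  const ToyOp &ra = redirectToParent(a);
  const ToyOp &rb = redirectToParent(b);
  return isBeforeInBlock(ra, rb) ? ra : rb;
}
```

With `op1` at position 1 of the outer block, the in-parallel op at position 2, and the insert-slice op nested inside it, `firstUser` picks `op1` without ever comparing ops from different blocks.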

https://github.com/llvm/llvm-project/pull/111955

