[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 16 18:45:32 PDT 2024


================
@@ -1418,52 +1560,93 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (isInsertSliceOp) {
     auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
     oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
     initSize = forOp.getInits().size();
   } else {
     auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
     oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
     initSize = forallOp.getOutputs().size();
     rank = forallOp.getRank();
   }
 
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+  Operation *oldTopLevelLoop = oldLoopOp;
+  SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
+  if (isInsertSliceOp) {
+    oldNestedForOps =
+        getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
+                               isForOpYieldResultOfInnerLoop);
+    oldTopLevelLoop = oldNestedForOps.front();
+  }
+  // 2.a Check assumption for loop and find suitable insertPoint that loop
+  // structure would be cloned right before.
+  Operation *insertPointBefore = nullptr;
+  if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp,
+                                    &insertPointBefore))) {
     return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
+        oldTopLevelLoop, "containing loop op does not satisfy the assumption "
+                         "and no suitable insertPoint is found");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
 
-  // 2. Check consumer is not using scf loop's output as init.
+  // 2.b Check consumer is not using scf loop's output as init.
   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
+  SmallVector<Value> newInitAppend = dpsInits;
 
   Location loc = oldLoopOp->getLoc();
 
   // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
+  rewriter.setInsertionPoint(insertPointBefore);
+
+  // 3.a Create new outer scf loops if necessary
----------------
Yun-Fly wrote:

Let's consider the following example:

```
%0 = scf.for() {
  %1 = scf.for() {
     ...
     %2 = insert_slice
     yield %2
  }
  yield %1
}
...
%3 = consumer ins(%0)
```

1. `tileAndFuseConsumerOfSlice` accepts a `candidateSliceOp`, so the consumer can only be fused into the second-level loop, `%1 = scf.for()`.
2. Then, following the design of the previous implementation, we have to clone the whole loop structure right before `%3 = consumer ins(%0)`.
3. Previously, `tileAndFuseConsumerOfSlice` only needed to clone the innermost loop, like `%1 = scf.for()` here, because it did not expect nested loops. In this PR, however, not only the innermost loop but also all of the outer loops need to be cloned. That is why we need to create new `scf.for`s there.

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list