[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 16 19:33:07 PDT 2024


================
@@ -1577,28 +1760,159 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         newForallOp.getBody()->getArguments().drop_front(rank + initSize));
   }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
+  // 12.  Restore outer loops from inner to outer
+  if (isNestedForOps) {
+    newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
+    for (auto [outerLoop, innerLoop] :
+         llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
+                         MutableArrayRef(newNestedForOps).drop_front())) {
+      auto forOp = cast<scf::ForOp>(outerLoop);
+      auto outerLoopYield =
+          cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+      SmallVector<Value> newYields =
+          llvm::to_vector(outerLoopYield.getOperands());
+      ValueRange additionalYields =
+          innerLoop->getResults().take_back(newInitAppend.size());
+      newYields.append(additionalYields.begin(), additionalYields.end());
+      rewriter.setInsertionPoint(outerLoopYield);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+    }
+  }
+
+  // 13. Replace the result of scf loop and consumer op with new loop's results.
   for (auto &&[oldResult, newResult] :
        llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
+  Operation *newTopLevelLoop =
+      isNestedForOps ? newNestedForOps.front() : newLoopOp;
   for (auto &&[oldResult, newResult] :
        llvm::zip(consumerOp->getResults(),
-                 newLoopOp->getResults().drop_front(initSize))) {
+                 newTopLevelLoop->getResults().drop_front(initSize))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the old scf loop and the cloned consumer op.
+  // 14. Need to erase the old scf loop and the cloned consumer op.
   rewriter.eraseOp(oldLoopOp);
   rewriter.eraseOp(clonedConsumerOp);
 
+  // 15. Need to erase the cloned insertSliceOp and unused extractSliceOp in
+  // avoid of complex domination analysis
+  assert(clonedInsertSliceOp->hasOneUse());
+  auto unUsedExtractOp =
+      cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers().begin()));
+  rewriter.eraseOp(unUsedExtractOp);
+  rewriter.eraseOp(clonedInsertSliceOp);
+
   return scf::SCFFuseConsumerOfSliceResult{
       consumerOpOperand,
       &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
       tileAndFuseResult->tiledOps};
 }
 
+/// Get the real consumers from candidate InsertSliceOp. E.g
+///
+/// ```
+/// %1 = scf.for
+///  %2 = scf.for
+///   %3 = scf.for
+///      ...
+///      %4 = insert
+///      yield %4
+///   %5 = insert %3
+///   yield %5
+///  yield %2
+/// %6 = consumerOp ins(%1)
+/// ```
+///
+/// @param candidateSliceOp: %4 = insert
+/// @param forwardSlice: in-out parameter populated by forward insertSliceOps
+/// @return OpOperand consumers: %6 = consumerOp ins(%1)
+static FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp(
----------------
Yun-Fly wrote:

As I explained in the thread above, `tileAndFuseConsumerOfSlice` requires a candidate slice to know where to fuse, so we need to feed the corresponding slice to it during the iterative process.

This method is used to collect the chain of all related insert slices found during the forward search, for example:

```
%0 = scf.for() {
  ...
  scf.for() {
      %1 = tensor.insert_slice
  }
  ...
  %2 = tensor.insert_slice
}
consumer ins(%0)
```

After fusing the consumer into the first loop level, it seems hard to feed the next inner candidate to `tileAndFuseConsumerOfSlice` without this method.


https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list