[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 9 18:51:47 PDT 2024


================
@@ -1654,52 +1729,99 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (isInsertSliceOp) {
     auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
     oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
     initSize = forOp.getInits().size();
   } else {
     auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
     oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
     initSize = forallOp.getOutputs().size();
     rank = forallOp.getRank();
   }
 
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+  // There are two possible cases regarding `oldLoopOp` here:
+  // 1. single `scf.forall` or `scf.for`.
+  // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
+  // top-level loop is the outer-most one of these nested loops.
+  Operation *oldTopLevelLoop = oldLoopOp;
+  SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
+  if (isInsertSliceOp) {
+    oldNestedForOps =
+        getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
+                               isForOpYieldResultOfInnerLoop);
+    oldTopLevelLoop = oldNestedForOps.front();
+  }
+
+  if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp))) {
     return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
+        oldTopLevelLoop,
+        "containing loop op should either yield just one value or "
+        "have the consumer op as its first user");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
 
   // 2. Check consumer is not using scf loop's output as init.
-  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp)
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer op is not DPS operation");
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
+  SmallVector<Value> newInitAppend = dpsInits;
 
   Location loc = oldLoopOp->getLoc();
 
   // 3. Create new scf loop op.
   rewriter.setInsertionPoint(consumerOp);
+
+  // 3.a Create new outer scf loops with new Inits only if nested `scf.for`
+  // case was found.
+  bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
----------------
MaheshRavishankar wrote:

OK, sorry, that was my bad... I see that the existing solution also creates new loops. If you want to add this, then I/you can clean this up after the fact.

Basically, I think this would be much easier if we just "moved" the operation before the consumer op and then just used `addInitOperandsToLoopNest`. But I/you can clean this up after the fact as well. I am modifying code in this area anyway, so I am fine with taking this up.

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list