[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 10 18:16:00 PDT 2024


================
@@ -1646,81 +1641,58 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  Operation *oldLoopOp = nullptr;
-  SmallVector<Value> newOuts;
-  Block *oldLoopBody = nullptr;
-  unsigned initSize = 0;
-  unsigned rank = 1;
+  // There are two possible cases regarding `oldLoopOp` here:
+  // 1. single `scf.forall` or `scf.for`.
+  // 2. inner-most `scf.for` inside a nested `scf.loop` structure, where the
+  // top-level loop is the outer-most one of these nested loops.
+  LoopLikeOpInterface innerMostLoop =
+      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
+  SmallVector<LoopLikeOpInterface> nestedLoops;
   if (isInsertSliceOp) {
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-    oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
-    initSize = forOp.getInits().size();
+    nestedLoops = getPerfectlyOuterNestedLoops(innerMostLoop);
   } else {
-    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
-    oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
-    initSize = forallOp.getOutputs().size();
-    rank = forallOp.getRank();
+    nestedLoops = {innerMostLoop};
   }
 
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+
+  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
     return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
+        outerMostLoop,
+        "containing loop op should either yield just one value or "
+        "have the consumer op as its first user");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
 
   // 2. Check consumer is not using scf loop's output as init.
-  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp)
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer op is not DPS operation");
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
-
-  Location loc = oldLoopOp->getLoc();
+  SmallVector<Value> newInits = dpsInits;
 
-  // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
-  Operation *newLoopOp = nullptr;
-  Block *newLoopBody = nullptr;
-  if (isInsertSliceOp) {
-    auto forOp = cast<scf::ForOp>(oldLoopOp);
-    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
-                                                forOp.getUpperBound(),
-                                                forOp.getStep(), newOuts);
-    newLoopOp = newForOp;
-    newLoopBody = newForOp.getBody();
-  } else {
-    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
-    auto newForallOp = rewriter.create<scf::ForallOp>(
-        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
-        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-    newLoopOp = newForallOp;
-    rewriter.eraseOp(newForallOp.getTerminator());
-    newLoopBody = newForallOp.getBody();
-  }
+  Location loc = outerMostLoop->getLoc();
 
-  // 4. Move the loop body to the new op.
-  unsigned oldNumArguments = oldLoopBody->getNumArguments();
-  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
-                       newLoopBody->getArguments().take_front(oldNumArguments));
+  // 3. Move the whole loop structure right before the consumer op; dominance
+  // should already be ensured by `checkAssumptionForLoop`.
+  outerMostLoop->moveBefore(consumerOp);
----------------
Yun-Fly wrote:

Thanks for this reminder! Changed.

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list