[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 16 09:32:10 PDT 2024


================
@@ -1418,52 +1560,93 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (isInsertSliceOp) {
     auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
     oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
     initSize = forOp.getInits().size();
   } else {
     auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
     oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
     initSize = forallOp.getOutputs().size();
     rank = forallOp.getRank();
   }
 
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+  Operation *oldTopLevelLoop = oldLoopOp;
+  SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
+  if (isInsertSliceOp) {
+    oldNestedForOps =
+        getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
+                               isForOpYieldResultOfInnerLoop);
+    oldTopLevelLoop = oldNestedForOps.front();
+  }
+  // 2.a Check assumption for loop and find suitable insertPoint that loop
+  // structure would be cloned right before.
+  Operation *insertPointBefore = nullptr;
+  if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp,
+                                    &insertPointBefore))) {
     return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
+        oldTopLevelLoop, "containing loop op does not satisfy the assumption "
+                         "and no suitable insertPoint is found");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
 
-  // 2. Check consumer is not using scf loop's output as init.
+  // 2.b Check consumer is not using scf loop's output as init.
   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
+  SmallVector<Value> newInitAppend = dpsInits;
 
   Location loc = oldLoopOp->getLoc();
 
   // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
+  rewriter.setInsertionPoint(insertPointBefore);
+
+  // 3.a Create new outer scf loops if necessary
----------------
MaheshRavishankar wrote:

I don't follow why we need to create new `scf.for` loops here...

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list