[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 4 21:27:06 PDT 2024


================
@@ -1316,187 +1501,288 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  Operation *oldLoopOp = nullptr;
-  SmallVector<Value> newOuts;
-  Block *oldLoopBody = nullptr;
-  unsigned initSize = 0;
-  unsigned rank = 1;
-  if (isInsertSliceOp) {
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-    oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
-    initSize = forOp.getInits().size();
-  } else {
-    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
-    oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
-    initSize = forallOp.getOutputs().size();
-    rank = forallOp.getRank();
-  }
-
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
-    return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
-  }
-
-  OpBuilder::InsertionGuard g(rewriter);
-
-  // 2. Check consumer is not using scf loop's output as init.
+  // 3. Check consumer is not using outerMostLoop's output as init.
   auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
+  ValueRange newInitAppend = dpsInits;
 
-  Location loc = oldLoopOp->getLoc();
+  // 4. reconstruct nested loop from outer to inner.
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
+      (*resultAndSliceOpsPair).second;
+  SmallVector<LoopLikeOpInterface> newOuterLoops;
+  SmallVector<SmallVector<OpFoldResult>> allResultOffsets, allResultSizes;
+  // extract slice from newInits of outer-most scf.forall
+  SmallVector<tensor::ExtractSliceOp> newExtractOps;
 
-  // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
-  Operation *newLoopOp = nullptr;
+  Block *oldLoopBody = nullptr;
   Block *newLoopBody = nullptr;
+  SmallVector<Value> newOuts;
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // set insertPoint right before consumerOp
+  rewriter.setInsertionPoint(consumerOp);
+
+  for (auto [index, loop] :
+       llvm::enumerate(MutableArrayRef(outerLoops).drop_back())) {
+    if (index > 0)
+      rewriter.setInsertionPoint(loop);
+
+    LoopLikeOpInterface newLoopOp;
+    // Create a new loop with the new init values for this loop.
+    if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+      newOuts = llvm::to_vector(forOp.getInits());
+      newOuts.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForOp>(
+          forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+          forOp.getStep(), newOuts);
+      newLoopOp = newLoop;
+      oldLoopBody = forOp.getBody();
+      newLoopBody = newLoop.getBody();
+      newInitAppend =
+          newLoopBody->getArguments().take_back(newInitAppend.size());
+    } else if (auto forallOp = dyn_cast<scf::ForallOp>(loop.getOperation())) {
+      newOuts = llvm::to_vector(forallOp.getOutputs());
+      newOuts.append(newInitAppend.begin(), newInitAppend.end());
+      auto newLoop = rewriter.create<scf::ForallOp>(
+          forallOp.getLoc(), forallOp.getMixedLowerBound(),
+          forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+          forallOp.getMapping());
+      rewriter.eraseOp(newLoop.getTerminator());
+      newLoopOp = newLoop;
+      oldLoopBody = forallOp.getBody();
+      newLoopBody = newLoop.getBody();
+
+      // create extractSliceOp for newInits
+      assert(index == 0 && "Currently Only outerMostLoop assumed ForallOp");
+      auto outerMostCandidate = candidateSliceOpList.back();
+      assert(isa<tensor::ParallelInsertSliceOp>(outerMostCandidate));
+      // set InsertPoint before next inner loop
+      auto nextLoop = outerLoops[index + 1];
+      rewriter.setInsertionPoint(nextLoop);
+      if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+              rewriter, cast<TilingInterface>(consumerOp), operandNumber,
+              outerMostCandidate, allResultOffsets, allResultSizes))) {
+        return failure();
+      }
+      fixSharedOutSCFForall(rewriter, newLoop, nextLoop, allResultOffsets,
+                            allResultSizes, newInitAppend.size(),
+                            newExtractOps);
+      newInitAppend = llvm::map_to_vector(
+          newExtractOps,
+          [](tensor::ExtractSliceOp op) -> Value { return op.getResult(); });
+    }
+    rewriter.mergeBlocks(
+        oldLoopBody, newLoopBody,
+        newLoopBody->getArguments().take_front(oldLoopBody->getNumArguments()));
+    rewriter.replaceOp(
+        loop, newLoopOp->getResults().take_front(loop->getNumResults()));
+    newOuterLoops.push_back(newLoopOp);
+  }
+
+  // 5.a reconstruct inner-most loop.
+  LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
+  Location loc = oldInnerMostLoop->getLoc();
+  if (outerLoops.size() > 1)
+    rewriter.setInsertionPoint(oldInnerMostLoop);
+
   if (isInsertSliceOp) {
-    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto forOp = cast<scf::ForOp>(oldInnerMostLoop.getOperation());
+    newOuts = llvm::to_vector(forOp.getInits());
+    newOuts.append(newInitAppend.begin(), newInitAppend.end());
+    oldLoopBody = forOp.getBody();
     auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                 forOp.getUpperBound(),
                                                 forOp.getStep(), newOuts);
-    newLoopOp = newForOp;
+    newInnerMostLoop = newForOp;
     newLoopBody = newForOp.getBody();
   } else {
-    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto forallOp = cast<scf::ForallOp>(oldInnerMostLoop.getOperation());
+    newOuts = llvm::to_vector(forallOp.getOutputs());
+    newOuts.append(newInitAppend.begin(), newInitAppend.end());
+    oldLoopBody = forallOp.getBody();
     auto newForallOp = rewriter.create<scf::ForallOp>(
         loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
         forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-    newLoopOp = newForallOp;
+    newInnerMostLoop = newForallOp;
     rewriter.eraseOp(newForallOp.getTerminator());
     newLoopBody = newForallOp.getBody();
   }
 
-  // 4. Move the loop body to the new op.
+  // 5.b Move the loop body to the new op.
   unsigned oldNumArguments = oldLoopBody->getNumArguments();
   rewriter.mergeBlocks(oldLoopBody, newLoopBody,
                        newLoopBody->getArguments().take_front(oldNumArguments));
+  // 5.c replace the result of old oldInnerMostLoop with newInnerMostLoop's
+  // results.
+  rewriter.replaceOp(oldInnerMostLoop,
+                     newInnerMostLoop->getResults().take_front(
+                         oldInnerMostLoop->getNumResults()));
 
-  // 5. Set insertion point before terminator op of the loop and create a new
+  // 6. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
   tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
     rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
   }
-
-  // 6.a. Clone consumer op.
-  auto newForOpBlockArgsForConsumerDest =
-      newLoopBody->getArguments().drop_front(oldNumArguments);
+  FailureOr<SmallVector<OpFoldResult>> realOffsets =
+      computeRealOffsetsCoordinatedRootSliceOp(rewriter, loc, ossSliceOp,
+                                               candidateSliceOpList);
+  if (failed(realOffsets))
+    return failure();
+  // create dummy insertSliceOp to align with the requirement of current
+  // Tiling interface and fix potential semantic mismatch with later
+  // extractSliceOp generated by `getTiledImplementation`.
+  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+      loc, candidateSliceOp->getOperand(0),
+      candidateSliceOpList.back()->getOperand(1), *realOffsets,
+      ossSliceOp.getMixedSizes(), ossSliceOp.getMixedStrides());
+
+  // 7.a. Clone consumer op.
+  SmallVector<Value> newDpsInitsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(oldNumArguments),
+      [](BlockArgument bArg) -> Value { return bArg; });
+  ;
+  // align dps inits if necessary
----------------
Yun-Fly wrote:

Hi, @Abhishek-Varma. I have pushed the change; please take a look.

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list