[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #108318)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 11 19:42:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: None (Yun-Fly)

<details>
<summary>Changes</summary>

This is a mirror of PR #94190 with a tiny build fix.

Sorry for the inconvenience.

---

Patch is 28.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108318.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+174-174) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+70-7) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..f4cf92201068ae 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1481,6 +1481,50 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
   return &operand;
 }
 
+/// Collect the perfectly nested `scf.for` loops enclosing `loop` (included),
+/// sorted from outermost to innermost. An enclosing `scf.for` is part of the
+/// nest only if its body holds just the inner `scf.for` plus a terminator that
+/// yields exactly the inner loop's results, in the same order. E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+///         %3 = ...
+///         yield %3
+///      yield %2
+///    yield %1
+/// ```
+///
+/// returns the three loops [%0, %1, %2] when the target inner loop is %2.
+static SmallVector<scf::ForOp>
+getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
+  SmallVector<scf::ForOp> nestLoops = {loop};
+  auto outerLoop = dyn_cast_if_present<scf::ForOp>(loop->getParentOp());
+
+  // Check if `outerLoop` is a ForOp that only forwards its inner loop.
+  auto isForOpYieldResultOfInnerLoop =
+      [](scf::ForOp outerLoop) -> LogicalResult {
+    Block *body = outerLoop.getBody();
+    if (!llvm::hasSingleElement(body->without_terminator()))
+      return failure();
+    auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
+    auto innerForOp = dyn_cast<scf::ForOp>(body->front());
+    if (!innerForOp)
+      return failure();
+    // Exact in-order match: callers map outer result #i to inner result #i.
+    return success(
+        llvm::equal(innerForOp->getResults(), yieldOp->getOperands()));
+  };
+
+  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
+    nestLoops.push_back(outerLoop);
+    outerLoop = dyn_cast_if_present<scf::ForOp>(outerLoop->getParentOp());
+  }
+  // Collected inner-to-outer; reverse to outer-to-inner.
+  return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
 /// tensor.insert_slice. This function makes the following assumptions :
 /// 1.  tensor.insert_slice has scf.yield as its only user.
@@ -1498,9 +1542,10 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
   auto forOp = dyn_cast<scf::ForOp>(containingOp);
   if (!forOp)
     return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
+  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
+  Value resultingValue = topLevelForOp->getResult(resultNumber);
 
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
+  return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
 }
 
 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1563,59 +1608,6 @@ static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
   }
 }
 
-/// After fusing consumer into scf.for we want to modify the scf.yield operation
-/// to reflect the same by returning the values yielded by the tiled consumer.
-static void
-fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      TilingResult &tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
-                      ArrayRef<BlockArgument> bbArgs) {
-  scf::YieldOp oldTerminatorOp =
-      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
-  unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
-  SmallVector<Value> newYieldOperands;
-  newYieldOperands.reserve(totalOldResults + totalTiledResults);
-  for (auto oldResult : oldTerminatorOp.getResults()) {
-    newYieldOperands.push_back(oldResult);
-  }
-  rewriter.setInsertionPointAfter(oldTerminatorOp);
-  Location loc = newForOp.getLoc();
-  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
-                       resultOffsets, resultSizes)) {
-    SmallVector<OpFoldResult> strides(resultOffset.size(),
-                                      rewriter.getIndexAttr(1));
-    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
-    newYieldOperands.push_back(newInsertSliceOp);
-  }
-  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
-  rewriter.eraseOp(oldTerminatorOp);
-}
-
-/// After fusing consumer into scf.forall we want to yield each of the resulting
-/// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
-                           ArrayRef<BlockArgument> bbArgs) {
-  scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
-  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
-  Location firstYieldOpLoc =
-      (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
-  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
-    SmallVector<OpFoldResult> strides(resultOffset.size(),
-                                      rewriter.getIndexAttr(1));
-    rewriter.create<tensor::ParallelInsertSliceOp>(
-        firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
-  }
-}
-
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1646,81 +1638,63 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  Operation *oldLoopOp = nullptr;
-  SmallVector<Value> newOuts;
-  Block *oldLoopBody = nullptr;
-  unsigned initSize = 0;
-  unsigned rank = 1;
+  // There are two possible cases regarding `oldLoopOp` here:
+  // 1. single `scf.forall` or `scf.for`.
+  // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
+  // top-level loop is the outer-most one of these nested loops.
+  LoopLikeOpInterface innerMostLoop =
+      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
+  SmallVector<LoopLikeOpInterface> nestedLoops;
   if (isInsertSliceOp) {
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-    oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
-    initSize = forOp.getInits().size();
+    nestedLoops = llvm::map_to_vector(
+        getPerfectlyNestedLoopsOutsideOf(
+            cast<scf::ForOp>(innerMostLoop.getOperation())),
+        [](scf::ForOp forOp) {
+          return cast<LoopLikeOpInterface>(forOp.getOperation());
+        });
   } else {
-    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
-    oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
-    initSize = forallOp.getOutputs().size();
-    rank = forallOp.getRank();
+    nestedLoops = {innerMostLoop};
   }
 
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+
+  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
     return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
+        outerMostLoop,
+        "containing loop op should either yield just one value or "
+        "have the consumer op as its first user");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
 
   // 2. Check consumer is not using scf loop's output as init.
-  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp)
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer op is not DPS operation");
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
-
-  Location loc = oldLoopOp->getLoc();
+  SmallVector<Value> newInits = dpsInits;
 
-  // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
-  Operation *newLoopOp = nullptr;
-  Block *newLoopBody = nullptr;
-  if (isInsertSliceOp) {
-    auto forOp = cast<scf::ForOp>(oldLoopOp);
-    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
-                                                forOp.getUpperBound(),
-                                                forOp.getStep(), newOuts);
-    newLoopOp = newForOp;
-    newLoopBody = newForOp.getBody();
-  } else {
-    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
-    auto newForallOp = rewriter.create<scf::ForallOp>(
-        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
-        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-    newLoopOp = newForallOp;
-    rewriter.eraseOp(newForallOp.getTerminator());
-    newLoopBody = newForallOp.getBody();
-  }
+  Location loc = outerMostLoop->getLoc();
 
-  // 4. Move the loop body to the new op.
-  unsigned oldNumArguments = oldLoopBody->getNumArguments();
-  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
-                       newLoopBody->getArguments().take_front(oldNumArguments));
+  // 3. Move the whole loop structure right before consumer Op, the dominance
+  // should be already ensured by `checkAssumptionForLoop`.
+  rewriter.moveOpBefore(outerMostLoop, consumerOp);
 
-  // 5. Set insertion point before terminator op of the loop and create a new
+  // 4. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
   tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
     rewriter.setInsertionPoint(newForallOp.getTerminator());
     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
@@ -1731,20 +1705,17 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
   }
 
-  // 6.a. Clone consumer op.
-  auto newForOpBlockArgsForConsumerDest =
-      newLoopBody->getArguments().drop_front(oldNumArguments);
-  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+  // 5.a. Clone consumer op.
+  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
 
-  // 6.b. Replace all uses of the loop result with the result of the cloned
+  // 5.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
     operandToReplace.set(clonedInsertSliceOp.getResult());
   });
 
-  // 7 - Perform tiling of the cloned consumer and replace the operand at
+  // 6. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
   auto ossSliceOp =
       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
@@ -1754,79 +1725,108 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (failed(tileAndFuseResult)) {
     return failure();
   }
-  rewriter.replaceAllUsesWith(
-      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
-      clonedInsertSliceOp.getSource());
-
-  // 8 - Extract offset/sizes/strides required to create the
-  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
-  // 9. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
+  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
+  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
+                              clonedInsertSliceOp.getSource());
 
-  // 10. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
+  // 7. Reconstruct [nested] loop with new inits.
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+    // 8. Set inner insertPoint right before tiled consumer op.
+    innerRewriter.setInsertionPoint(tiledConsumerOp);
 
-  // 11. Try to fetch the offset and size for all results of the cloned
-  // consumer. This would then be used to form the corresponding
-  // tensor.insert_slice/parallel_insert_slice later.
-  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-    if (failed(clonedConsumerOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
+    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+    // 9. Check all insert stride is 1.
+    if (llvm::any_of(strides, [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
       return rewriter.notifyMatchFailure(
-          clonedConsumerOp,
-          "can't get result domain position from iter domain position");
+          candidateSliceOp, "containingOp's result yield with stride");
     }
-  }
 
-  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
-  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
-  if (isInsertSliceOp) {
-    auto newForOp = cast<scf::ForOp>(newLoopOp);
-    fixTerminatorSCFYield(
-        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
-        newForOp.getBody()->getArguments().drop_front(1 + initSize));
-  } else {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
-    fixTerminatorSCFInParallel(
-        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
-        arrayRefOffsets, arrayRefSizes,
-        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
-  }
+    // 10. Try to get iter domain position from input position.
+    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+    if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
+            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+            iterDomainSizes))) {
+      return rewriter.notifyMatchFailure(
+          tiledConsumerOp,
+          "can't get iter domain position from input position");
+    }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
-  for (auto &&[oldResult, newResult] :
-       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
-    rewriter.replaceAllUsesWith(oldResult, newResult);
+    // 11. Try to fetch the offset and size for all results of the cloned
+    // consumer. This would then be used to form the corresponding
+    // tensor.insert_slice/parallel_insert_slice later.
+    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
+    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+        totalNumResultsOfConsumer);
+    SmallVector<SmallVector<OpFoldResult>> resultSizes(
+        totalNumResultsOfConsumer);
+    for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
+      if (failed(tiledConsumerOp.getResultTilePosition(
+              rewriter, idx, iterDomainOffsets, iterDomainSizes,
+              resultOffsets[idx], resultSizes[idx]))) {
+        return rewriter.notifyMatchFailure(
+            tiledConsumerOp,
+            "can't get result domain position from iter domain position");
+      }
+    }
+
+    // 12. Create `extract_slice` for `iter_args` for DPS operation if
+    // necessary.
+    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
+            tiledConsumerOp.getOperation())) {
+      rewriter.setInsertionPoint(tiledDestStyleOp);
+      for (const auto &&[index, newRegionArg] :
+           llvm::enumerate(newRegionIterArgs)) {
+        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+            loc, newRegionArg, resultOffsets[index], resultSizes[index],
+            SmallVector<OpFoldResult>(resultOffsets[index].size(),
+                                      rewriter.getIndexAttr(1)));
+        // Make C++ 17 happy, otherwise it will throw error `captured structured
+        // bindings are a C++20 extension`.
+        auto dstNumber = index;
+        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
+          tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
+        });
+      }
+    }
+
+    // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
+    // caller.
+    Block *block = rewriter.getInsertionPoint()->getBlock();
+    rewriter.setInsertionPoint(block->getTerminator());
+    for (const auto &&[index, result] :
+         llvm::enumerate(tiledConsumerOp->getResults())) {
+      tiledResult.push_back(result);
+      tiledOffset.emplace_back(resultOffsets[index]);
+      tiledSizes.emplace_back(resultSizes[index]);
+    }
+    return success();
+  };
+  // 14. Add new inits to [nested] loops.
+  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
+                                       newYieldValuesFn))) {
+    return rewriter.notifyMatchFailure(tiledConsumerOp,
+                                       "unable to add new inits to nest loop");
   }
 
-  for (auto &&[oldResult, newResult] :
-       llvm::zip(consumerOp->getResults(),
-                 newLoopOp->getResults().drop_front(initSize))) {
+  // 15. Replace the result of scf loop and consumer op with new loop's results.
+
+  for (auto &&[oldResult, newResult] : llvm::zip(
+           consumerOp->getResults(),
+           nestedLoops.front()->getResults().take_back(newInits.size()))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the old scf loop and the cloned consumer op.
-  rewriter.eraseOp(oldLoopOp);
+  // 16. Need to erase the old scf loop and the cloned consumer op.
   rewriter.eraseOp(clonedConsumerOp);
 
   return scf::SCF...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/108318


More information about the Mlir-commits mailing list