[Mlir-commits] [mlir] 335538c - Revert "[mlir][scf] Extend consumer fuse to single nested `scf.for` (#94190)"

Kazu Hirata llvmlistbot at llvm.org
Wed Sep 11 19:18:44 PDT 2024


Author: Kazu Hirata
Date: 2024-09-11T19:18:37-07:00
New Revision: 335538c271c9c71ef3f2e23680265e7b77595be0

URL: https://github.com/llvm/llvm-project/commit/335538c271c9c71ef3f2e23680265e7b77595be0
DIFF: https://github.com/llvm/llvm-project/commit/335538c271c9c71ef3f2e23680265e7b77595be0.diff

LOG: Revert "[mlir][scf] Extend consumer fuse to single nested `scf.for` (#94190)"

This reverts commit 2d4bdfba96d4cf88b12226b2b511bf55ee5e6559.

A build breakage is reported at:

https://lab.llvm.org/buildbot/#/builders/138/builds/3524

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 04624638e14c00..e404c01010a325 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1481,50 +1481,6 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
   return &operand;
 }
 
-/// Find the perfectly nested loops outside of given loop(included) sorted from
-/// outer to inner.
-///
-/// E.g.
-///
-/// ```
-///  %0 = scf.for()
-///    %1 = scf.for()
-///      %2 = scf.for()
-///         %3 = ...
-///         yield %3
-///      yield %2
-///    yield %1
-/// ```
-///
-/// This function will return three perfectly nested loops: %0 + %1 + %2, when
-/// target inner loop is %2.
-static SmallVector<scf::ForOp>
-getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
-  SmallVector<scf::ForOp> nestLoops = {loop};
-  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
-
-  // Check if it is the ForOp that yield the result of inner loop.
-  auto isForOpYieldResultOfInnerLoop =
-      [](scf::ForOp outerLoop) -> LogicalResult {
-    Block *body = outerLoop.getBody();
-    if (!llvm::hasSingleElement(body->without_terminator()))
-      return failure();
-    auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
-    auto innerForOp = dyn_cast<scf::ForOp>(body->front());
-    if (!innerForOp)
-      return failure();
-    // All of innerForOp results should be yielded.
-    return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
-  };
-
-  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
-    nestLoops.push_back(outerLoop);
-    outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
-  }
-  // sorted from outer to inner
-  return {nestLoops.rbegin(), nestLoops.rend()};
-}
-
 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
 /// tensor.insert_slice. This function makes the following assumptions :
 /// 1.  tensor.insert_slice has scf.yield as its only user.
@@ -1542,10 +1498,9 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
   auto forOp = dyn_cast<scf::ForOp>(containingOp);
   if (!forOp)
     return failure();
-  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
-  Value resultingValue = topLevelForOp->getResult(resultNumber);
+  Value resultingValue = forOp->getResult(resultNumber);
 
-  return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
+  return getConsumerFromUses(resultingValue, containingOp->getBlock());
 }
 
 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1608,6 +1563,59 @@ static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
   }
 }
 
+/// After fusing consumer into scf.for we want to modify the scf.yield operation
+/// to reflect the same by returning the values yielded by the tiled consumer.
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      TilingResult &tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ArrayRef<BlockArgument> bbArgs) {
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  unsigned totalOldResults = oldTerminatorOp->getNumResults();
+  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  SmallVector<Value> newYieldOperands;
+  newYieldOperands.reserve(totalOldResults + totalTiledResults);
+  for (auto oldResult : oldTerminatorOp.getResults()) {
+    newYieldOperands.push_back(oldResult);
+  }
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  Location loc = newForOp.getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
+                       resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+}
+
+/// After fusing consumer into scf.forall we want to yield each of the resulting
+/// values by the tiled consumer within scf.forall.in_parallel region.
+static void
+fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
+                           SmallVector<Value> tiledResults,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                           ArrayRef<BlockArgument> bbArgs) {
+  scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
+  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+  Location firstYieldOpLoc =
+      (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
+  }
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1638,63 +1646,81 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  // There are two possible cases regarding `oldLoopOp` here:
-  // 1. single `scf.forall` or `scf.for`.
-  // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
-  // top-level loop is the outer-most one of these nested loops.
-  LoopLikeOpInterface innerMostLoop =
-      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
-  SmallVector<LoopLikeOpInterface> nestedLoops;
+  Operation *oldLoopOp = nullptr;
+  SmallVector<Value> newOuts;
+  Block *oldLoopBody = nullptr;
+  unsigned initSize = 0;
+  unsigned rank = 1;
   if (isInsertSliceOp) {
-    nestedLoops = llvm::map_to_vector(
-        getPerfectlyNestedLoopsOutsideOf(
-            cast<scf::ForOp>(innerMostLoop.getOperation())),
-        [](scf::ForOp forOp) {
-          return cast<LoopLikeOpInterface>(forOp.getOperation());
-        });
+    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+    oldLoopOp = forOp;
+    llvm::append_range(newOuts, forOp.getInits());
+    oldLoopBody = forOp.getBody();
+    initSize = forOp.getInits().size();
   } else {
-    nestedLoops = {innerMostLoop};
+    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
+    oldLoopOp = forallOp;
+    llvm::append_range(newOuts, forallOp.getOutputs());
+    oldLoopBody = forallOp.getBody();
+    initSize = forallOp.getOutputs().size();
+    rank = forallOp.getRank();
   }
 
-  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
-
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
+  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
     return rewriter.notifyMatchFailure(
-        outerMostLoop,
-        "containing loop op should either yield just one value or "
-        "have the consumer op as its first user");
+        oldLoopOp, "containing loop op should either yield just one value or "
+                   "have the consumer op as its first user");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
 
   // 2. Check consumer is not using scf loop's output as init.
-  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
-  if (!dstOp)
-    return rewriter.notifyMatchFailure(consumerOp,
-                                       "consumer op is not DPS operation");
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  SmallVector<Value> newInits = dpsInits;
+  newOuts.append(dpsInits);
+
+  Location loc = oldLoopOp->getLoc();
 
-  Location loc = outerMostLoop->getLoc();
+  // 3. Create new scf loop op.
+  rewriter.setInsertionPoint(consumerOp);
+  Operation *newLoopOp = nullptr;
+  Block *newLoopBody = nullptr;
+  if (isInsertSliceOp) {
+    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                                forOp.getUpperBound(),
+                                                forOp.getStep(), newOuts);
+    newLoopOp = newForOp;
+    newLoopBody = newForOp.getBody();
+  } else {
+    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto newForallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+    newLoopOp = newForallOp;
+    rewriter.eraseOp(newForallOp.getTerminator());
+    newLoopBody = newForallOp.getBody();
+  }
 
-  // 3. Move the whole loop structure right before consumer Op, the dominance
-  // should be already ensured by `checkAssumptionForLoop`.
-  rewriter.moveOpBefore(outerMostLoop, consumerOp);
+  // 4. Move the loop body to the new op.
+  unsigned oldNumArguments = oldLoopBody->getNumArguments();
+  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(oldNumArguments));
 
-  // 4. Set insertion point before terminator op of the loop and create a new
+  // 5. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
   tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
+    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
     rewriter.setInsertionPoint(newForallOp.getTerminator());
     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
@@ -1705,17 +1731,20 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
   }
 
-  // 5.a. Clone consumer op.
-  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
+  // 6.a. Clone consumer op.
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(oldNumArguments);
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
 
-  // 5.b. Replace all uses of the loop result with the result of the cloned
+  // 6.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
     operandToReplace.set(clonedInsertSliceOp.getResult());
   });
 
-  // 6. Perform tiling of the cloned consumer and replace the operand at
+  // 7 - Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
   auto ossSliceOp =
       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
@@ -1725,105 +1754,79 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (failed(tileAndFuseResult)) {
     return failure();
   }
-  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
-  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
-                              clonedInsertSliceOp.getSource());
-
-  // 7. Reconstruct [nested] loop with new inits.
-  YieldTiledValuesFn newYieldValuesFn =
-      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
-          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
-          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
-          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
-    OpBuilder::InsertionGuard g(innerRewriter);
-    // 8. Set inner insertPoint right before tiled consumer op.
-    innerRewriter.setInsertionPoint(tiledConsumerOp);
-
-    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+  rewriter.replaceAllUsesWith(
+      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
+      clonedInsertSliceOp.getSource());
+
+  // 8 - Extract offset/sizes/strides required to create the
+  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+  // 9. Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
 
-    // 9. Check all insert stride is 1.
-    if (llvm::any_of(strides, [](OpFoldResult stride) {
-          return !isConstantIntValue(stride, 1);
-        })) {
-      return rewriter.notifyMatchFailure(
-          candidateSliceOp, "containingOp's result yield with stride");
-    }
+  // 10. Try to get iter domain position from input position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
 
-    // 10. Try to get iter domain position from input position.
-    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-    if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
-            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-            iterDomainSizes))) {
+  // 11. Try to fetch the offset and size for all results of the cloned
+  // consumer. This would then be used to form the corresponding
+  // tensor.insert_slice/parallel_insert_slice later.
+  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
       return rewriter.notifyMatchFailure(
-          tiledConsumerOp,
-          "can't get iter domain position from input position");
-    }
-
-    // 11. Try to fetch the offset and size for all results of the cloned
-    // consumer. This would then be used to form the corresponding
-    // tensor.insert_slice/parallel_insert_slice later.
-    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
-    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-        totalNumResultsOfConsumer);
-    SmallVector<SmallVector<OpFoldResult>> resultSizes(
-        totalNumResultsOfConsumer);
-    for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
-      if (failed(tiledConsumerOp.getResultTilePosition(
-              rewriter, idx, iterDomainOffsets, iterDomainSizes,
-              resultOffsets[idx], resultSizes[idx]))) {
-        return rewriter.notifyMatchFailure(
-            tiledConsumerOp,
-            "can't get result domain position from iter domain position");
-      }
-    }
-
-    // 12. Create `extract_slice` for `iter_args` for DPS operation if
-    // necessary.
-    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
-            tiledConsumerOp.getOperation())) {
-      rewriter.setInsertionPoint(tiledDestStyleOp);
-      for (const auto &&[index, newRegionArg] :
-           llvm::enumerate(newRegionIterArgs)) {
-        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            loc, newRegionArg, resultOffsets[index], resultSizes[index],
-            SmallVector<OpFoldResult>(resultOffsets[index].size(),
-                                      rewriter.getIndexAttr(1)));
-        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
-          tiledDestStyleOp.getDpsInitsMutable()[index].set(destSlice);
-        });
-      }
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
     }
+  }
 
-    // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
-    // caller.
-    Block *block = rewriter.getInsertionPoint()->getBlock();
-    rewriter.setInsertionPoint(block->getTerminator());
-    for (const auto &&[index, result] :
-         llvm::enumerate(tiledConsumerOp->getResults())) {
-      tiledResult.push_back(result);
-      tiledOffset.emplace_back(resultOffsets[index]);
-      tiledSizes.emplace_back(resultSizes[index]);
-    }
-    return success();
-  };
-  // 14. Add new inits to [nested] loops.
-  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
-                                       newYieldValuesFn))) {
-    return rewriter.notifyMatchFailure(tiledConsumerOp,
-                                       "unable to add new inits to nest loop");
+  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
+  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
+  if (isInsertSliceOp) {
+    auto newForOp = cast<scf::ForOp>(newLoopOp);
+    fixTerminatorSCFYield(
+        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
+        newForOp.getBody()->getArguments().drop_front(1 + initSize));
+  } else {
+    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    fixTerminatorSCFInParallel(
+        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
+        arrayRefOffsets, arrayRefSizes,
+        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
   }
 
-  // 15. Replace the result of scf loop and consumer op with new loop's results.
+  // 12. Replace the result of scf loop and consumer op with new loop's results.
+  for (auto &&[oldResult, newResult] :
+       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
+  }
 
-  for (auto &&[oldResult, newResult] : llvm::zip(
-           consumerOp->getResults(),
-           nestedLoops.front()->getResults().take_back(newInits.size()))) {
+  for (auto &&[oldResult, newResult] :
+       llvm::zip(consumerOp->getResults(),
+                 newLoopOp->getResults().drop_front(initSize))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 16. Need to erase the old scf loop and the cloned consumer op.
+  // 13. Need to erase the old scf loop and the cloned consumer op.
+  rewriter.eraseOp(oldLoopOp);
   rewriter.eraseOp(clonedConsumerOp);
 
   return scf::SCFFuseConsumerOfSliceResult{

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index fdefdcc453ae7a..83c5ec8d7342c8 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -109,9 +109,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
 // CHECK-SAME:              outs(%[[SLICE_OUT]] :
 //      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
-//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#2 :
@@ -248,10 +248,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
 // CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
 //      CHECK:      scf.forall.in_parallel {
-//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
-//      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   %[[UNPACK:.*]] = tensor.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32>
@@ -310,8 +310,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:                              outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
 // CHECK-SAME:                              into %[[TILED_UNPACK_DEST]]
 //      CHECK:      scf.forall.in_parallel {
-//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:          tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
@@ -369,71 +369,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:                              inner_dims_pos = [0] inner_tiles = [16]
 // CHECK-SAME:                              into %[[TILED_PACK_DEST]]
 //      CHECK:      scf.forall.in_parallel {
-//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 //      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]],  %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
-
-// -----
-
-module {
-  func.func @fuse_add_consumer_into_nested_scf_for(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
-    %c0 = arith.constant 0 : index
-    %c64 = arith.constant 64 : index
-    %c256 = arith.constant 256 : index
-    %cst = arith.constant 0.000000e+00 : f32
-    %dest0 = tensor.empty() : tensor<256x256xf32>
-    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest1) -> (tensor<256x256xf32>) {
-      %2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) {
-        %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
-        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
-        %extracted_slice_3 = tensor.extract_slice %arg1[0, %arg5] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
-        %3 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_1 : tensor<64x64xf32>) -> tensor<64x64xf32>
-        %insert_slice = tensor.insert_slice %3 into %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
-        scf.yield %insert_slice : tensor<256x256xf32>
-      }
-      scf.yield %2 : tensor<256x256xf32>
-    }
-    %4 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    return %4 : tensor<256x256xf32>
-  }
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %slice_op
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-//      CHECK: func.func @fuse_add_consumer_into_nested_scf_for(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
-// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
-//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
-//      CHECK:   %[[dest1:.*]] = linalg.fill
-// CHECK-SAME:          outs(%[[dest0]] :
-//      CHECK:   %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[dest0]])
-// CHECK-SAME:   {
-//      CHECK:       %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
-// CHECK-SAME:         iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
-// CHECK-SAME:         {
-//      CHECK:            %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
-//      CHECK:            %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
-//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
-// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE]] :
-//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
-// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
-// CHECK-SAME:              outs(%[[ADD_OUT_SLICE]] :
-//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
-//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
-//      CHECK:         }
-//      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
 //      CHECK:   }
-//      CHECK:   return %[[LOOP_RESULT1]]#1 :
+//      CHECK:   return %[[FINAL_RESULT]]#1 :


        


More information about the Mlir-commits mailing list