[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Abhishek Varma llvmlistbot at llvm.org
Mon Apr 15 04:11:11 PDT 2024


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/88712

>From 8e044b2e584383e6d74ff0b1d2a915264179b6ec Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 10 Apr 2024 10:41:46 +0000
Subject: [PATCH] [MLIR][SCF] Add an API to fuse consumer to a producer within
 scf loop

-- This commit adds an API to fuse consumer to a producer within
   scf.for/scf.forall loop.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../SCF/Transforms/TileUsingInterface.h       |  13 +
 .../mlir/Interfaces/TilingInterface.td        |  55 ++
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  96 +++-
 .../SCF/Transforms/TileUsingInterface.cpp     | 477 ++++++++++++++++++
 .../TilingInterface/fuse-consumer.mlir        | 110 ++++
 .../TestTilingInterfaceTransformOps.cpp       |  53 ++
 .../TestTilingInterfaceTransformOps.td        |  21 +
 7 files changed, 804 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be28..75b48a2cdd8dc3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -126,6 +126,19 @@ struct SCFTileAndFuseOptions {
   }
 };
 
+/// Result of `tileAndFuseConsumerOfSlice`, declared right after this struct.
+/// That function fuses the consumer of the loop result that
+/// `candidateSliceOp` inserts into, computing the consumer tile inside the
+/// loop; `useSCFFor` selects the scf.for path over the scf.forall path.
+struct SCFFuseConsumerOfSliceResult {
+  Operation *origConsumer;     // Original untiled consumer.
+  Value tiledAndFusedConsumer; // Tiled and fused consumer value.
+  SmallVector<Operation *> tiledOps; // Currently always returned empty.
+};
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+                           bool useSCFFor);
+
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
 /// required slice of the producer in-place.  Note that the method
 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return {};
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the iteration domain tile (offsets and sizes)
+          computed from the position of the tile of the given operand.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand position.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the position of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to consume the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the required tile in the
+            coordinate system of operand `operandNumber` (not of the
+            iteration space); implementations map it to the iteration
+            space, e.g. through
+            `getIterDomainTilePositionFromOperandPosition`.
+          - `sizes` provides the size of the required operand tile.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..01bf19764b0938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
+  void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
+                              AffineMap indexingMap,
+                              ArrayRef<OpFoldResult> offsets,
+                              ArrayRef<OpFoldResult> sizes,
+                              SmallVector<OpFoldResult> &mappedOffsets,
+                              SmallVector<OpFoldResult> &mappedSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    auto numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    mappedOffsets.resize(numLoops); // One entry per loop.
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) { // Seed with the full domain.
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &range : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[range.index()] = range.value().offset;
+        mappedSizes[range.index()] = range.value().size;
+      }
+    }
+    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition =
+          cast<AffineDimExpr>(resultExpr.value()).getPosition();
+      mappedOffsets[dimPosition] = offsets[resultExpr.index()]; // Use tile.
+      mappedSizes[dimPosition] = sizes[resultExpr.index()];
+    }
+  }
+
+  // Map a tile of operand `operandNumber`, given by `offsets`/`sizes`, to
+  // the iteration-domain tile (`iterDomainOffsets`/`iterDomainSizes`).
+  LogicalResult getIterDomainTilePositionFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &iterDomainOffsets,
+      SmallVector<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the operand).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitOpError(
+          "unhandled get iter domain position when operand is not "
+          "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
   // Return the details of the output tile generated by the tiled
   // implementation.
   LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes; // Iteration tile.
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return op->emitOpError(
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
+                           mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69df..74202c8f9d4e9f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -798,6 +799,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 // tileConsumerAndFuseProducersUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  // Step 1. Fetch the parent op result at the operand number of `source`.
+  // TODO(avarma): Make it generic for multiple values yielding scf.for.
+  unsigned yieldOperandNumber = source.getOperandNumber();
+  Value resultingValue =
+      source.getOwner()->getParentOp()->getResult(yieldOperandNumber);
+
+  // Step 2. Pick the first user; stays null when the result has no users.
+  std::optional<OpOperand *> destinationIterArg;
+  Operation *untiledConsumer = nullptr;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFForall(OpOperand *source,
+                                         unsigned &operandNumber) {
+  // Step 1. Fetch result 0 of the op owning the block argument `source`.
+  // TODO(avarma): Make it generic for multiple values yielding scf.forall.
+  auto iterArg = cast<BlockArgument>(source->get());
+  Value resultingValue = iterArg.getOwner()->getParentOp()->getResult(0);
+
+  // Step 2. Pick the first user; stays null when the result has no users.
+  std::optional<OpOperand *> destinationIterArg;
+  Operation *untiledConsumer = nullptr;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
 /// Return the untiled producer whose slice is used in a tiled consumer. The
 /// method traverses the tile loop nest (`loops`) if needed, and returns the
 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
@@ -820,6 +874,429 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] =
+      getUntiledConsumerFromSliceDestSCFForall(
+          &candidateSliceOp.getDestMutable(), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    llvm::outs() << "containing op is not a scf.forall: " << *containingOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check that the consumer uses at most one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        llvm::outs()
+            << "consumer's operand use more than one containingOp's result\n";
+        return failure();
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    llvm::outs() << "consumer op should have destination style op interface"
+                 << "\n";
+    return failure();
+  }
+
+  // Check that the consumer does not use scf.forall's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    llvm::outs() << "consumer op take result of scf.forall as init\n";
+    return failure();
+  }
+
+  // Result number of the scf.forall result consumed by `consumerOp`.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    llvm::outs() << "containingOp's result yield with stride\n";
+    return failure();
+  }
+
+  Location loc = forallOp.getLoc();
+  rewriter.setInsertionPoint(terminatorOp);
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Try to get iter domain position from input position.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    llvm::outs() << "can't get iter domain position from input position\n";
+    return failure();
+  }
+
+  // Get every consumer result's tile position from iter domain position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(consumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      llvm::outs()
+          << "can't get result domain position from iter domain position\n";
+      return failure();
+    }
+  }
+
+  // All checks passed, try to fuse consumer.
+  // Create tiled implementation of the consumer op.
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableConsumer.getTiledImplementationFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes);
+  if (failed(tileAndFuseResult)) {
+    llvm::outs() << "get tiled implementation failed\n";
+    return failure();
+  }
+
+  auto tiledOps = tileAndFuseResult->tiledOps;
+  if (tiledOps.size() != 1) {
+    llvm::outs() << "failed to tile consumer op: " << *tileableConsumer << "\n";
+    return failure();
+  }
+
+  // Replace tiled op's operand.
+  for (auto operandNum : operandNums) {
+    tiledOps[0]->setOperand(operandNum, candidateSliceOp.getSource());
+  }
+  rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return forallOp->isProperAncestor(op);
+                             });
+
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+
+  // Create new scf.forall op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  for (auto v : dpsInits) {
+    newforallOp.getBody()->addArgument(v.getType(), v.getLoc());
+    auto bbArgs = newforallOp.getBody()->getArguments();
+    rewriter.replaceUsesWithIf(v, bbArgs.back(), [&](OpOperand &use) {
+      Operation *op = use.getOwner();
+      return newforallOp->isProperAncestor(op);
+    });
+  }
+
+  // Fix terminator.
+  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+      newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+  Operation *firstYieldOp = yieldingOps.front();
+  rewriter.setInsertionPoint(firstYieldOp);
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  for (auto [idx, v] : llvm::enumerate(tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOp->getLoc(), v,
+        bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+        resultPositions[idx].first, resultPositions[idx].second, strides);
+  }
+
+  // Replace the result of forall and consumer op.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforallOp->getResult(forallOp.getOutputs().size() +
+                               consumerResult.index()));
+  }
+
+  // Need to erase the old forall.
+  rewriter.eraseOp(forallOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tiledOps[0]->getResult(0), {}};
+}
+
+static bool
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+  Value result = candidateSliceOp.getResult();
+  // The slice must have exactly one use, and that use must be the scf.yield.
+  if (!result.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected exactly one use of the slice op\n");
+    return false;
+  }
+  if (!isa<scf::YieldOp>(*result.user_begin())) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected scf.yield to be the only user\n");
+    return false;
+  }
+  return true;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // Requires the slice result to be used only once, by the scf.yield.
+  if (!checkAssumptionForFusingConsumer(candidateSliceOp)) {
+    return failure();
+  }
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] = getUntiledConsumerFromSliceDestSCFFor(
+      candidateSliceOp->getOpOperand(operandNumber), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check containing op is "scf::ForOp".
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp) {
+    llvm::outs() << "containing op is not a scf.for: " << *containingOp << "\n";
+    return failure();
+  }
+
+  // Check that the consumer uses at most one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        llvm::outs()
+            << "consumer's operand use more than one containingOp's result\n";
+        return failure();
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    llvm::outs() << "consumer op should have destination style op interface\n";
+    return failure();
+  }
+
+  // Check consumer is not using scf.for's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    llvm::outs() << "consumer op take result of scf.for as init\n";
+    return failure();
+  }
+
+  // Result number of the scf.for result consumed by `consumerOp`.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+
+  Location loc = forOp.getLoc();
+  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    llvm::outs() << "containingOp's result yield with stride\n";
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Try to get iter domain position from input position.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    llvm::outs() << "can't get iter domain position from input position\n";
+    return failure();
+  }
+
+  // Get every consumer result's tile position from iter domain position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(consumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      llvm::outs()
+          << "can't get result domain position from iter domain position\n";
+      return failure();
+    }
+  }
+
+  // All checks passed, try to fuse consumer.
+  // Create tiled implementation of the consumer op.
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableConsumer.getTiledImplementationFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes);
+  if (failed(tileAndFuseResult)) {
+    llvm::outs() << "get tiled implementation failed\n";
+    return failure();
+  }
+
+  auto tiledOps = tileAndFuseResult->tiledOps;
+  if (tiledOps.size() != 1) {
+    llvm::outs() << "failed to tile consumer op: " << *tileableConsumer << "\n";
+    return failure();
+  }
+
+  // Replace tiled op's operand.
+  for (auto operandNum : operandNums) {
+    tiledOps[0]->setOperand(operandNum, candidateSliceOp.getSource());
+  }
+  rewriter.replaceUsesWithIf(bridge, forOp.getInits()[bridgeResultIdx],
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return forOp->isProperAncestor(op);
+                             });
+
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  rewriter.eraseBlock(newforOp.getBody());
+  newforOp.getRegion().takeBody(forOp.getRegion());
+
+  for (auto v : dpsInits) {
+    newforOp.getBody()->addArgument(v.getType(), v.getLoc());
+    auto bbArgs = newforOp.getBody()->getArguments();
+    rewriter.replaceUsesWithIf(v, bbArgs.back(), [&](OpOperand &use) {
+      Operation *op = use.getOwner();
+      return newforOp->isProperAncestor(op);
+    });
+  }
+
+  // Fix terminator.
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+
+  SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+  rewriter.setInsertionPoint(candidateSliceOp);
+  auto bbArgs = newforOp.getBody()->getArguments();
+  for (auto [idx, v] : llvm::enumerate(tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    newYieldOperands.push_back(rewriter.create<tensor::InsertSliceOp>(
+        candidateSliceOp->getLoc(), v,
+        bbArgs[1 + forOp.getInits().size() + idx], resultPositions[idx].first,
+        resultPositions[idx].second, strides));
+  }
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+
+  rewriter.eraseOp(oldTerminatorOp);
+
+  // Replace the result of for and consumer op.
+  for (auto result : llvm::enumerate(forOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
+  }
+
+  // Need to erase the old for.
+  rewriter.eraseOp(forOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tiledOps[0]->getResult(0), {}};
+}
+
+/// Fuse the consumer of the loop result fed by `candidateSliceOp`; the op must
+/// be tensor.insert_slice if `useSCFFor`, else tensor.parallel_insert_slice.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp,
+                                      bool useSCFFor) {
+  if (useSCFFor) {
+    return tileAndFuseConsumerOfSliceSCFFor(
+        rewriter, cast<tensor::InsertSliceOp>(candidateSliceOp));
+  }
+  return tileAndFuseConsumerOfSliceSCFForall(
+      rewriter, cast<tensor::ParallelInsertSliceOp>(candidateSliceOp));
+}
+
 /// Implementation of fusing producer of a single slice by computing the
 /// slice of the producer in-place.
 std::optional<scf::SCFFuseProducerOfSliceResult>
diff --git a/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir
new file mode 100644
index 00000000000000..bc9da4979b8efb
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+// Test input for the scf.for case: the loop computes a linalg.generic on a
+// slice of its iter_arg and writes it back with tensor.insert_slice; the loop
+// result is then read by a linalg.elemwise_binary outside the loop. The
+// expected output below has that elementwise op inside the loop.
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    // Producer loop; %4 is the tensor.insert_slice the transform script targets.
+    %1 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2) -> (tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %4 : tensor<64xf32>
+    }
+    // Both tensor.empty ops are identical, so the --cse run in the test
+    // pipeline is expected to leave a single one in the output.
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    // Consumer of the loop result.
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+// Transform script for the scf.for case: match the tensor.insert_slice in the
+// payload and run the consumer-fusion test op on it with `use_for true`.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %slice_op use_for true
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// Expected output for the scf.for case. All SSA values are matched through
+// FileCheck captures so the test does not depend on printer-assigned names.
+//      CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[MAT_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[MAT_OUT_ARG]]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[MAT_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+// Test input for the scf.forall case: the loop computes a linalg.matmul on a
+// slice of its shared output and writes it back with
+// tensor.parallel_insert_slice; the loop result is then read by a
+// linalg.elemwise_binary outside the loop. The expected output below has that
+// elementwise op inside the loop.
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    // Producer loop; the tensor.parallel_insert_slice in its terminator is the
+    // op the transform script targets.
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    // Both tensor.empty ops are identical, so the --cse run in the test
+    // pipeline is expected to leave a single one in the output.
+    %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+    // Consumer of the loop result.
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}
+
+// Transform script for the scf.forall case: match the
+// tensor.parallel_insert_slice in the payload and run the consumer-fusion test
+// op on it with `use_for false`.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %slice_op use_for false
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// Expected output for the scf.forall case. The induction variables are
+// captured from the loop header instead of relying on printer-assigned names.
+//      CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:[a-zA-Z0-9]+]], %[[IV2:[a-zA-Z0-9]+]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[MAT_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[MAT_OUT_ARG]]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]]
+//      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[MAT_OUT_ARG]]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 335db1a61f476e..66d87791cda2a8 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -160,6 +160,59 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                         : DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerOp
+//===----------------------------------------------------------------------===//
+
+/// Runs consumer fusion on every op in `payloadOps`, in iteration order, and
+/// records two handles on `transformOp`: result 0 receives the original
+/// (untiled) consumer ops and result 1 receives the ops defining the tiled and
+/// fused consumer values. Stops at the first payload op for which fusion fails
+/// and returns failure without setting any result.
+template <typename Range>
+static LogicalResult applyFuseConsumer(RewriterBase &rewriter,
+                                       Operation *transformOp,
+                                       Range &&payloadOps, bool useFor,
+                                       TransformResults &transformResults) {
+  SmallVector<Operation *> untiledConsumers;
+  SmallVector<Operation *> tiledConsumers;
+
+  for (Operation *sliceOp : payloadOps) {
+    // New IR is created right before the slice op being processed.
+    rewriter.setInsertionPoint(sliceOp);
+
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
+        scf::tileAndFuseConsumerOfSlice(rewriter, sliceOp, useFor);
+    if (failed(fused))
+      return failure();
+
+    // Collect the ops that are reported back through the result handles.
+    Value fusedValue = fused->tiledAndFusedConsumer;
+    untiledConsumers.push_back(fused->origConsumer);
+    tiledConsumers.push_back(fusedValue.getDefiningOp());
+  }
+
+  transformResults.set(transformOp->getOpResult(0), untiledConsumers);
+  transformResults.set(transformOp->getOpResult(1), tiledConsumers);
+  return success();
+}
+
+/// Applies consumer fusion to all payload ops associated with the target
+/// handle. Any failure of the underlying fusion is reported as a definite
+/// failure.
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  if (failed(applyFuseConsumer(rewriter, getOperation(),
+                               state.getPayloadOps(getTarget()), getUseFor(),
+                               transformResults)))
+    return DiagnosedSilenceableFailure::definiteFailure();
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares the side effects of this transform op by appending to `effects`:
+/// the target handle is consumed, the two result handles are produced, and the
+/// payload IR is modified.
+void transform::TestFuseConsumerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The target handle is declared as consumed by this op.
+  consumesHandle(getTarget(), effects);
+  // Both results are newly produced handles.
+  producesHandle(getConsumer(), effects);
+  producesHandle(getFusedConsumer(), effects);
+  // The op is declared as modifying the payload IR.
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // TestTileUsingForallOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index ef42375e5286d8..70fdc1d338ebbb 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,27 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
   }];
 }
 
+// Test-only transform op exercising scf::tileAndFuseConsumerOfSlice.
+def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the consumer of the operation pointed to by the target handle.
+
+    The `use_for` attribute (default: `true`) selects which kind of slice
+    operation the target is treated as: a `tensor.insert_slice` (`scf.for`
+    path) when `true`, or a `tensor.parallel_insert_slice` (`scf.forall`
+    path) when `false`.
+
+    The `consumer` result is a handle to the original consumer operation and
+    the `fused_consumer` result is a handle to the operation defining the
+    tiled and fused consumer value. The target handle is consumed.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+        DefaultValuedAttr<BoolAttr, "true">:$use_for);
+  let results = (outs TransformHandleTypeInterface:$consumer,
+                      TransformHandleTypeInterface:$fused_consumer);
+
+  let assemblyFormat = [{
+    $target (`use_for` $use_for^)? attr-dict
+    `:` functional-type(operands, results)
+  }];
+}
+
 def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,



More information about the Mlir-commits mailing list