[Mlir-commits] [mlir] ea75511 - [mlir][TilingInterface] Enable tile and fuse using TilingInterface.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue Jun 21 09:47:24 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-06-21T16:47:14Z
New Revision: ea75511319d9dff8c38c8794c3949c40b63a38d7

URL: https://github.com/llvm/llvm-project/commit/ea75511319d9dff8c38c8794c3949c40b63a38d7
DIFF: https://github.com/llvm/llvm-project/commit/ea75511319d9dff8c38c8794c3949c40b63a38d7.diff

LOG: [mlir][TilingInterface] Enable tile and fuse using TilingInterface.

This patch implements tile and fuse transformation for ops that
implement the tiling interface. To do so,
- `TilingInterface` needs a new method that generates a tiled
  implementation of the operation based on the tile of the result
  needed.
- A pattern is added that replaces a `tensor.extract_slice` whose
  source is defined by an operation that implements the
  `TilingInterface` with a tiled implementation that produces the
  extracted slice in-place (using the method added to
  `TilingInterface`).
- A pattern is added that takes a sequence of operations that
  implement the `TilingInterface` (for now `LinalgOp`s), tiles the
  consumer, and greedily fuses its producers iteratively.
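As a rough usage sketch (not part of this patch), the new pattern is added to a
`RewritePatternSet` like any other pattern; the helper name `tileAndFuse` and the
tile sizes are illustrative only, and the in-tree test pass below additionally wraps
the pattern in a `LinalgTransformationFilter` so it is not re-applied to already
tiled operations.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: tile the matched consumer with sizes {10, 20} and
// greedily fuse its producers. A tile size of 0 leaves that loop untiled,
// which is how a reduction dimension is kept out of the fusion.
static LogicalResult tileAndFuse(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes({10, 20});

  RewritePatternSet patterns(context);
  patterns.add<scf::TileConsumerAndFuseProducersUsingSCFForOp>(context,
                                                               tilingOptions);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```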

Differential Revision: https://reviews.llvm.org/D127809

Added: 
    mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
    mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 6e8af767ff8a3..1f3ee8a5b27f6 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -10,9 +10,12 @@
 #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
+#include <deque>
+
 namespace mlir {
 class Operation;
 class PatternRewriter;
@@ -55,7 +58,7 @@ struct SCFTilingResult {
   SmallVector<scf::ForOp> loops;
 };
 
-/// Pattern to tile an op that implementas the `TilingInterface` using
+/// Pattern to tile an op that implements the `TilingInterface` using
 /// `scf.for` for iterating over the tiles.
 struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
   /// Construct a generic pattern applied to all TilingInterface ops.
@@ -81,6 +84,56 @@ struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
   SCFTilingOptions options;
 };
 
+/// Pattern to tile and fuse a sequence of operations, by tiling the consumer
+/// and fusing its producers. Note that this assumes that it is valid to
+/// tile+fuse the producer into the innermost tiled loop. It's up to the caller
+/// to ensure that the tile sizes provided make this fusion valid.
+///
+/// For example, for the following sequence
+///
+/// ```mlir
+/// %0 = linalg.fill ...
+/// %1 = linalg.matmul ... outs(%0 : ...) ...
+/// ```
+///
+/// it is legal to fuse the fill with the matmul only if the matmul is tiled
+/// along the parallel dimensions and not the reduction dimension, i.e. the tile
+/// size for the reduction dimension should be 0.
+struct SCFTileAndFuseResult {
+  SmallVector<Operation *> tiledAndFusedOps;
+  SmallVector<scf::ForOp> loops;
+};
+struct TileConsumerAndFuseProducersUsingSCFForOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+
+  /// Construct a generic pattern applied to all TilingInterface ops.
+  TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
+                                            SCFTilingOptions options,
+                                            PatternBenefit benefit = 1);
+
+  /// Construct a generic pattern applied to `opName`.
+  TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
+                                            MLIRContext *context,
+                                            SCFTilingOptions options,
+                                            PatternBenefit benefit = 1);
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<SCFTileAndFuseResult>
+  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+
+private:
+  /// This pattern uses the tiling pattern internally. Instead of using
+  /// inheritance, the tiling pattern is kept as a private member that is
+  /// instantiated at the same time as this pattern.
+  TileUsingSCFForOp tilingPattern;
+};
+
 } // namespace scf
 } // namespace mlir
 

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e6267e9cf02e5..28c22aecdf318 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -20,6 +21,14 @@ namespace tensor {
 void populateSplitPaddingPatterns(RewritePatternSet &patterns,
                                   PatternBenefit baseBenefit = 1);
 
+/// Swap a `tensor.extract_slice` with its producer when the producer
+/// implements the `TilingInterface`. This function does not provide a
+/// mechanism to control where it is applied; when driven through the
+/// transform dialect, that control comes from the transform dialect itself.
+/// Other use cases can wrap this function in a pattern and add controls.
+FailureOr<Value> replaceExtractSliceWithTiledProducer(
+    OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
+
 } // namespace tensor
 } // namespace mlir
 

diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 606901375ede8..f3fdc30168b28 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -120,7 +120,48 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         /*defaultImplementation=*/[{
           return failure();
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the code that produces a tile of the result.
+
+          Generates the IR that computes the tile of a result of the
+          operation. The `offsets` and `sizes` describe the tile of
+          the output required. This is different from
+          `getTiledImplementation`, which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space; here the tiled implementation is generated
+          based on the tile of the result required. This is what
+          enables tile-and-fuse. The method returns failure if the
+          operation cannot be tiled to generate the result tile, which
+          in practical terms implies it cannot be tiled and fused with
+          its consumers.
+
+          - `dest` are the Values into which the result of the tiled
+            operation is to be inserted. The types of the `dest`
+            values are the same as the types returned by the
+            `getDestinationOperands` method.
+          - `offsets` provides the offset of the requested tile in the
+            coordinate system of the result.
+          - `sizes` provides the size of that tile.
+          - `tileDestOperands` specifies whether to also tile `dest` operands
+            or not. Avoiding tiling `dest` operands can be useful for 
+            composition with various looping container ops.
+        }],
+        /*retType=*/"FailureOr<Value>",
+        /*methodName=*/"generateResultTileValue",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "ValueRange":$dest,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes,
+          "bool":$tileDestOperands),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
       >
-  ];
+  ];  
 }
 #endif // MLIR_TILINGINTERFACE
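A hedged sketch (not part of this patch) of how a transformation might invoke the
new interface method; `materializeResultTile` is a hypothetical helper, and
`offsets`/`sizes` are assumed to describe the requested result tile.

```cpp
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

// Hypothetical helper: produce the tile of result `resultNumber` described by
// `offsets`/`sizes`, reusing the op's own destination operands as `dest`.
static FailureOr<Value> materializeResultTile(OpBuilder &b, TilingInterface op,
                                              unsigned resultNumber,
                                              ArrayRef<OpFoldResult> offsets,
                                              ArrayRef<OpFoldResult> sizes) {
  SmallVector<Value> dest = op.getDestinationOperands(b);
  // Tiling the `dest` operands as well matches what the extract_slice swap in
  // this patch does; pass `false` to compose with looping container ops.
  return op.generateResultTileValue(b, resultNumber, dest, offsets, sizes,
                                    /*tileDestOperands=*/true);
}
```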

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c67097ab3d695..88b21f15081f3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -30,7 +30,6 @@ template <typename LinalgOpTy>
 struct LinalgOpTilingInterface
     : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                             LinalgOpTy> {
-
   /// Return the destination operands.
   SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
     return llvm::cast<LinalgOp>(op).getOutputOperands();
@@ -47,6 +46,8 @@ struct LinalgOpTilingInterface
 
   /// Return the iteration domain range.
   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(op);
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
     auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
@@ -129,16 +130,65 @@ struct LinalgOpTilingInterface
     resultSizes = sliceOp.getMixedSizes();
     return success();
   }
+
+  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
+                                           unsigned resultNumber,
+                                           ValueRange dest,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes,
+                                           bool tileDestOperands) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the output is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the result to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitOpError(
+          "unhandled tiled implementation generation when result is not "
+          "accessed using a permuted projection");
+    }
+
+    auto numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
+        iterationTileSizes(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (auto range : llvm::enumerate(iterationDomain)) {
+        iterationTileOffsets[range.index()] = range.value().offset;
+        iterationTileSizes[range.index()] = range.value().size;
+      }
+    }
+    for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition =
+          resultExpr.value().cast<AffineDimExpr>().getPosition();
+      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
+      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
+    }
+
+    SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
+        b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
+    if (tiledOp.size() != 1)
+      return op->emitOpError("failed to generate tiled implementation");
+
+    return tiledOp[0]->getResult(resultNumber);
+  }
 };
 
 } // namespace
 
-template <typename OpType> static void registerOne(MLIRContext *ctx) {
+template <typename OpType>
+static void registerOne(MLIRContext *ctx) {
   OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
 }
 
 /// Variadic helper function.
-template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
   // FIXME: In c++17 this can be simplified by using 'fold expressions'.
   (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
 }
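To make the result-tile to iteration-tile mapping implemented in
`generateResultTileValue` above concrete, an illustrative worked example (not from
the patch) for `linalg.matmul`, whose result indexing map is
`(d0, d1, d2) -> (d0, d1)`:

```cpp
// Requested result tile (e.g. coming from a tensor.extract_slice):
//   offsets = {%iv0, %iv1}          sizes = {10, 20}
//
// d0 and d1 appear in the result map, so they take the requested values; d2
// (the reduction) is not used to access the result and keeps its full range
// from getIterationDomain():
//   iterationTileOffsets = {%iv0, %iv1, 0}
//   iterationTileSizes   = {10,   20,   %K}   // %K: full reduction extent
//
// These are then handed to getTiledImplementation to build the tiled matmul.
```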

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 4646abcf3e8d1..1bad67f3d7f4d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -42,6 +42,10 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
   return *this;
 }
 
+//===----------------------------------------------------------------------===//
+// TileUsingSCFForOp pattern implementation.
+//===----------------------------------------------------------------------===//
+
 /// Generate an empty loop nest that represents the tiled loop nest shell.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
@@ -247,3 +251,155 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
   rewriter.replaceOp(op, tilingResult.loops.front().getResults());
   return tilingResult;
 }
+
+//===----------------------------------------------------------------------===//
+// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
+//===----------------------------------------------------------------------===//
+
+scf::TileConsumerAndFuseProducersUsingSCFForOp::
+    TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
+                                              scf::SCFTilingOptions options,
+                                              PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      tilingPattern(context, std::move(options)) {}
+
+scf::TileConsumerAndFuseProducersUsingSCFForOp::
+    TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
+                                              MLIRContext *context,
+                                              scf::SCFTilingOptions options,
+                                              PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      tilingPattern(context, std::move(options)) {}
+
+/// Return, as an `OpResult`, the untiled producer of `v` if it is defined by
+/// an operation that implements the `TilingInterface`. Looks through
+/// `iter_args` of an `scf.for` nest if required.
+static Optional<OpResult> getFusableProducer(Value v) {
+  while (auto blockArg = v.dyn_cast<BlockArgument>()) {
+    auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
+    if (!loopOp)
+      return llvm::None;
+    v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
+  }
+  if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
+    return llvm::None;
+  return v.cast<OpResult>();
+}
+
+FailureOr<scf::SCFTileAndFuseResult>
+scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  // This transformation is only valid for ops that return values (i.e. not
+  // valid to use with operations that have memref operands).
+  if (!op->getNumResults()) {
+    return rewriter.notifyMatchFailure(
+        op, "invalid pattern for op with no results");
+  }
+
+  // 1. First tile the consumer.
+  SCFTileAndFuseResult tileAndFuseResult;
+  {
+    FailureOr<SCFTilingResult> tilingResult =
+        tilingPattern.returningMatchAndRewrite(op, rewriter);
+    if (failed(tilingResult)) {
+      return failure();
+    }
+    tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
+    tileAndFuseResult.loops = std::move(tilingResult->loops);
+  }
+
+  // 2. Typically, the operands of the tiled operation are slices of the
+  //    operands of the untiled operation. These are expressed in IR using
+  //    `tensor.extract_slice` operations with source being the operands of the
+  //    untiled operation. Create a worklist of these `tensor.extract_slice`
+  //    operations. If the producers of the source of the `tensor.extract_slice`
+  //    can be tiled such that the tiled value is generated in-place, that
+  //    effectively tiles + fuses the operations.
+  auto addCandidateSlices = [](Operation *fusedOp,
+                               std::deque<tensor::ExtractSliceOp> &candidates) {
+    for (Value operand : fusedOp->getOperands())
+      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+        candidates.push_back(sliceOp);
+  };
+
+  std::deque<tensor::ExtractSliceOp> candidates;
+  addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
+  OpBuilder::InsertionGuard g(rewriter);
+  while (!candidates.empty()) {
+    // 2a. Traverse the slices in BFS fashion.
+    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
+    candidates.pop_front();
+
+    // 2b. Get the producer of the source (potentially walking through
+    // `iter_args` of nested `scf.for`)
+    Optional<OpResult> fusableProducer =
+        getFusableProducer(candidateSliceOp.source());
+    if (!fusableProducer)
+      continue;
+
+    // 2c. Generate the tiled implementation of the producer of the source
+    rewriter.setInsertionPoint(candidateSliceOp);
+    FailureOr<Value> fusedProducerValue =
+        tensor::replaceExtractSliceWithTiledProducer(
+            rewriter, candidateSliceOp, fusableProducer.getValue());
+    if (failed(fusedProducerValue))
+      continue;
+    rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue());
+
+    // 2d. The operands of the fused producer might themselves be slices of
+    //     values produced by operations that implement the `TilingInterface`.
+    //     Add these operations to the worklist.
+    Operation *fusedProducer = fusedProducerValue->getDefiningOp();
+    tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
+    addCandidateSlices(fusedProducer, candidates);
+
+    // 2e. If the operation being fused creates a value that is used as `outs`
+    //     in the tiled operation, the result of the unfused operation will be
+    //     used in the `iter_args` of the tiled loop generated. When the
+    //     operation is fused, this use in `iter_args` needs to be modified to
+    //     use the destination of the fused operation. For example, starting
+    //     with
+    //
+    //     ```mlir
+    //     %0 = linalg.init_tensor ...
+    //     %1 = linalg.fill ... outs(%0:...)...
+    //     %2 = linalg.matmul ... outs(%1:...)....
+    //     ```
+    //
+    //     First the `linalg.matmul` gets tiled
+    //
+    //     ```mlir
+    //     %0 = linalg.init_tensor
+    //     %1 = linalg.fill
+    //     %2 = scf.for .... iter_args(%arg0 = %1)...
+    //        ...
+    //        ... = linalg.matmul ...
+    //
+    //     ```
+    //
+    //     When the `linalg.fill` gets fused, the `iter_args` needs to be
+    //     modified
+    //
+    //     ```mlir
+    //     %0 = linalg.init_tensor
+    //     %1 = scf.for ... iter_args(%arg0 = %0)...
+    //        ...
+    //        %2 = linalg.fill ...
+    //        %3 = linalg.matmul ... outs(%2: ...)...
+    //     ```
+    TilingInterface unfusedProducerOp =
+        cast<TilingInterface>(fusableProducer->getOwner());
+    scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
+    SmallVector<Value> unfusedProducerOpDestValues =
+        unfusedProducerOp.getDestinationOperands(rewriter);
+    for (OpOperand &uses : unfusedProducerOp->getUses()) {
+      if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
+        unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
+        unsigned operandNumber = uses.getOperandNumber();
+        outerMostTiledLoop->setOperand(
+            operandNumber, unfusedProducerOpDestValues[resultNumber]);
+      }
+    }
+  }
+  return tileAndFuseResult;
+}

diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index f4983c4d5c886..8479c43211e83 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   SplitPadding.cpp
+  SwapExtractSliceWithProducer.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
@@ -18,5 +19,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRPass
   MLIRSCFDialect
   MLIRTensorDialect
+  MLIRTilingInterface
   MLIRTransforms
   )

diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp
new file mode 100644
index 0000000000000..8d570cfdf7592
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp
@@ -0,0 +1,43 @@
+//===- SwapExtractSliceWithProducer.cpp - Swapping `tensor.extract_slice` ---=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Swap a `tensor.extract_slice` with the producer of the source if the producer
+// implements the `TilingInterface`. When used in conjunction with tiling this
+// effectively tiles + fuses the producer with its consumer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+using namespace mlir;
+
+FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
+    OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
+  auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
+  if (!producerOp)
+    return failure();
+
+  // `TilingInterface` currently only supports strides being 1.
+  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 1);
+      }))
+    return failure();
+
+  FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
+      builder, producer.getResultNumber(),
+      producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(),
+      sliceOp.getMixedSizes(), true);
+  if (failed(tiledResult))
+    return failure();
+
+  return tiledResult.getValue();
+}
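For completeness, a sketch (not in this patch) of how this helper could be wrapped
in a standalone rewrite pattern when no external control over where it applies is
needed; the pattern name `SwapSliceWithProducer` is hypothetical.

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical pattern that swaps every eligible tensor.extract_slice with a
// tiled implementation of its producer.
struct SwapSliceWithProducer : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Only handle slices whose source is directly an op result.
    auto producer = sliceOp.source().dyn_cast<OpResult>();
    if (!producer)
      return failure();
    // The helper checks that the producer implements the TilingInterface and
    // that the slice has unit strides.
    FailureOr<Value> tiled = tensor::replaceExtractSliceWithTiledProducer(
        rewriter, sliceOp, producer);
    if (failed(tiled))
      return failure();
    rewriter.replaceOp(sliceOp, tiled.getValue());
    return success();
  }
};
```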

diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
new file mode 100644
index 0000000000000..dd77211d8ccc6
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -0,0 +1,185 @@
+// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s
+
+func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %gemm : tensor<?x?xf32>
+}
+//      CHECK: func.func @gemm_fill_fusion(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME:           outs(%[[INIT_TILE]] :
+//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:           outs(%[[FILL_TILE]] :
+//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
+//      CHECK:       scf.yield %[[INSERT]]
+
+// -----
+
+func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm = linalg.matmul
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %generic = linalg.generic {
+      __internal_linalg_transform__ = "fusion",
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32 
+  } -> tensor<?x?xf32>
+  return %generic : tensor<?x?xf32> 
+}
+//      CHECK: func.func @gemm_generic_fusion(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME:           outs(%[[INIT_TILE]] :
+//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:           outs(%[[FILL_TILE]] :
+//  CHECK-DAG:       %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
+//  CHECK-DAG:       %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
+//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
+// CHECK-SAME:           ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
+// CHECK-SAME:           outs(%[[OUTS_TILE]] :
+//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
+//      CHECK:       scf.yield %[[INSERT]]
+
+// -----
+
+func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
+  %init0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm0 = linalg.matmul
+      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
+  %init1 = linalg.init_tensor [%d0, %d2] : tensor<?x?xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm1 = linalg.matmul  {__internal_linalg_transform__ = "gemm_fusion"}
+      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %gemm1 : tensor<?x?xf32>
+}
+//      CHECK: func.func @gemm_gemm_fusion(
+// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]]
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
+//      CHECK:   %[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]]
+//      CHECK:   scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG:.+]] = %[[INIT1]])
+//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0]
+//      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT0_TILE]] :
+//      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:         outs(%[[FILL0_TILE]] :
+//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+//      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT1_TILE]] :
+//      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
+// CHECK-SAME:         outs(%[[FILL1_TILE]] :
+//      CHECK:     %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
+//      CHECK:     scf.yield %[[INSERT]]
+
+// -----
+
+func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm = linalg.matmul
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %init1 = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>
+  %transpose = linalg.generic {
+      __internal_linalg_transform__ = "fusion",
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      linalg.yield %b0 : f32 
+  } -> tensor<?x?xf32>
+  return %transpose : tensor<?x?xf32>
+}
+//      CHECK: func.func @gemm_transpose_fusion(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT1]])
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+//  CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]]
+//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME:           outs(%[[INIT0_TILE]] :
+//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:           outs(%[[FILL_TILE]] :
+//  CHECK-DAG:       %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
+// CHECK-SAME:           ins(%[[GEMM_TILE]] :
+// CHECK-SAME:           outs(%[[OUTS_TILE]] :
+//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+//      CHECK:       scf.yield %[[INSERT]]

diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 1e094329db66f..a7367a713ff4f 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -split-input-file %s | FileCheck %s
 
 func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {

diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index f3ba7a1c5f52d..6241603d6a679 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -29,8 +29,9 @@ using namespace mlir;
 
 namespace {
 
-/// Construct a generic pattern applied to all TilingInterface ops that verify
-/// `filter`.
+/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
+/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
+/// using a `filter` to avoid recursive application.
 struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
   TestTileUsingSCFForOpWithFilter(MLIRContext *context,
                                   scf::SCFTilingOptions options,
@@ -52,8 +53,7 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
     if (failed(filter.checkAndNotify(rewriter, op)))
       return failure();
 
-    FailureOr<scf::SCFTilingResult> tilingResult =
-        returningMatchAndRewrite(op, rewriter);
+    auto tilingResult = returningMatchAndRewrite(op, rewriter);
     if (failed(tilingResult)) {
       return failure();
     }
@@ -65,6 +65,50 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
   linalg::LinalgTransformationFilter filter;
 };
 
+/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
+/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
+/// ops for iterating over the tiles) while using a `filter` to avoid recursive
+/// application.
+struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter
+    : public scf::TileConsumerAndFuseProducersUsingSCFForOp {
+  TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
+      MLIRContext *context, scf::SCFTilingOptions options,
+      linalg::LinalgTransformationFilter filter =
+          linalg::LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
+                                                       benefit),
+        filter(filter) {}
+
+  /// Construct a generic pattern applied to `opName`.
+  TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
+      StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+      linalg::LinalgTransformationFilter filter =
+          linalg::LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
+                                                       benefit),
+        filter(filter) {}
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, op)))
+      return failure();
+
+    auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter);
+    if (failed(tileAndFuseResult)) {
+      return failure();
+    }
+    filter.replaceLinalgTransformationFilter(
+        rewriter, tileAndFuseResult->tiledAndFusedOps.front());
+    return success();
+  }
+
+private:
+  linalg::LinalgTransformationFilter filter;
+};
+
+/// Test pass for testing the use of `TilingInterface`.
 struct TestTilingInterfacePass
     : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
@@ -82,29 +126,63 @@ struct TestTilingInterfacePass
     return "Test tiling using TilingInterface";
   }
 
+  Option<bool> testTiling{
+      *this, "tile-using-scf-for",
+      llvm::cl::desc(
+          "Test tiling using TilingInterface with scf.for operations"),
+      llvm::cl::init(false)};
+
+  Option<bool> testTileConsumerAndFuseProducer{
+      *this, "tile-consumer-and-fuse-producer-using-scf-for",
+      llvm::cl::desc("Test tile and fuse transformation using TilingInterface "
+                     "with scf.for operations"),
+      llvm::cl::init(false)};
+
   void runOnOperation() override;
+
+private:
+  void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns);
 };
 } // namespace
 
-static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
-  auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
-                                 StringRef filterName) {
-    scf::SCFTilingOptions tilingOptions;
-    tilingOptions.setTileSizes(tileSizes);
-    linalg::LinalgTransformationFilter filter(
-        StringAttr::get(context, filterName),
-        StringAttr::get(context, "tiled"));
-    patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
-                                                  filter);
-  };
-  // 1. Tiling M and N dims of `linalg.matmul` on tensors.
-  addPatternForTiling({10, 20}, "simple_gemm");
-  // 2. Tiling M, N and K of `linalg.matmul` on buffers.
-  addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
-  // 3. Tiling 3D parallel generic op which implements a transpose
-  addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
-  // 4. Tiling 2D conv op.
-  addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
+template <class Pattern>
+static void
+addPatternForTiling(MLIRContext *context, ArrayRef<int64_t> tileSizes,
+                    StringRef filterName, RewritePatternSet &patterns) {
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.setTileSizes(tileSizes);
+  linalg::LinalgTransformationFilter filter(
+      StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
+  patterns.add<Pattern>(context, tilingOptions, filter);
+}
+
+void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
+                                              RewritePatternSet &patterns) {
+  if (testTiling) {
+    // 1. Tiling M and N dims of `linalg.matmul` on tensors.
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, {10, 20}, "simple_gemm", patterns);
+    // 2. Tiling M, N and K of `linalg.matmul` on buffers.
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, {10, 20, 30}, "simple_gemm_memref", patterns);
+    // 3. Tiling 3D parallel generic op which implements a transpose
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, {10, 0, 20}, "parallel_generic_transpose", patterns);
+    // 4. Tiling 2D conv op.
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
+    return;
+  }
+  if (testTileConsumerAndFuseProducer) {
+    // 1. Tile and fuse of gemm with bias-add operation.
+    addPatternForTiling<
+        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+        context, {10, 20}, "fusion", patterns);
+    addPatternForTiling<
+        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+        context, {10}, "gemm_fusion", patterns);
+    return;
+  }
 }
 
 void TestTilingInterfacePass::runOnOperation() {

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8ef7787306916..f0813db443a5c 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1881,6 +1881,7 @@ cc_library(
         ":SCFUtils",
         ":Support",
         ":TensorDialect",
+        ":TensorTransforms",
         ":TilingInterface",
         ":Transforms",
         "//llvm:Support",
@@ -5028,6 +5029,7 @@ cc_library(
         ":SCFDialect",
         ":TensorDialect",
         ":TensorPassIncGen",
+        ":TilingInterface",
         ":Transforms",
         "//llvm:Support",
     ],


        

