[Mlir-commits] [mlir] cf6a7c1 - [mlir][TilingInterface] Add pattern to tile using TilingInterface and implement TilingInterface for Linalg ops.

Mahesh Ravishankar llvmlistbot at llvm.org
Mon Jun 13 13:37:59 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-06-13T20:37:44Z
New Revision: cf6a7c1947931df20cf294213c8edf2d8c0490f6

URL: https://github.com/llvm/llvm-project/commit/cf6a7c1947931df20cf294213c8edf2d8c0490f6
DIFF: https://github.com/llvm/llvm-project/commit/cf6a7c1947931df20cf294213c8edf2d8c0490f6.diff

LOG: [mlir][TilingInterface] Add pattern to tile using TilingInterface and implement TilingInterface for Linalg ops.

This patch adds support for tiling operations that implement the
TilingInterface.
- It separates the loop constructs that are used to iterate over tiles
  from the implementation of the tiling itself. For example, the use
  of destructive updates is tied to using scf.for to iterate over
  tiles that are tensors.
- To test the transformation, TilingInterface is implemented for
  LinalgOps. The separation of the looping constructs used from the
  implementation of tile code generation greatly simplifies the
  latter.
- The implementation of TilingInterface for LinalgOp is kept as an
  external model for now until this approach is fully fleshed out and
  can replace the existing tiling + fusion approaches in Linalg.

Differential Revision: https://reviews.llvm.org/D127133

Added: 
    mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h
    mlir/include/mlir/Dialect/SCF/TileUsingInterface.h
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
    mlir/test/lib/Interfaces/CMakeLists.txt
    mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/SCF/Utils/Utils.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/test/lib/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h
new file mode 100644
index 0000000000000..5b88f1d05ce84
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- TilingInterfaceImpl.h - Implementation of TilingInterface ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+/// Register, in `registry`, the external models that implement
+/// `TilingInterface` for Linalg structured ops.
+void registerTilingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 3e9d07209017c..36b143b2d2ce8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -164,11 +164,11 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes);
 
-/// Compute tile sizes, given a list of loop `ivs`, `tileSizes` and dimension
+/// Compute tile sizes, given a list of `tileSizes` and dimension
 /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
 /// corresponding result size is the corresponding value from `sizeBounds`.
 /// Note: The returned tile sizes are closed intervals.
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds);
 

diff  --git a/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h
new file mode 100644
index 0000000000000..25911cecdc1e5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h
@@ -0,0 +1,87 @@
+//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
+#define MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+namespace mlir {
+namespace scf {
+
+/// Callback returning the tile sizes to use for an operation, one `Value` per
+/// loop. A tile size of zero leaves the corresponding loop untiled.
+using SCFTileSizeComputationFunction =
+    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
+
+/// Options to use to control tiling.
+struct SCFTilingOptions {
+  /// Computation function that returns the tile sizes for each operation.
+  /// Delayed construction of constant tile sizes should occur to interoperate
+  /// with folding.
+  SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
+
+  /// Set the function used to compute the tile sizes.
+  SCFTilingOptions &
+  setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
+    tileSizeComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
+  /// values must not fold away when tiling. Otherwise, use a more robust
+  /// `tileSizeComputationFunction`.
+  SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
+    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
+    return *this;
+  }
+  /// Convenience function to set the `tileSizeComputationFunction` to a
+  /// function that computes tile sizes at the point they are needed. Allows
+  /// proper interaction with folding.
+  SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+};
+
+/// Result of tiling an operation with `TileUsingSCFForOp`.
+struct SCFTilingResult {
+  /// The tiled operation. It is created inside the innermost generated loop,
+  /// or next to the original operation when no loops are generated.
+  Operation *tiledOp = nullptr;
+  /// The generated `scf.for` loops, outermost first. Empty when every tile
+  /// size is zero.
+  SmallVector<scf::ForOp> loops;
+};
+
+/// Pattern to tile an op that implements the `TilingInterface` using
+/// `scf.for` for iterating over the tiles.
+struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
+  /// Construct a generic pattern applied to all TilingInterface ops.
+  TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options,
+                    PatternBenefit benefit = 1);
+
+  /// Construct a generic pattern applied to `opName`.
+  /// NOTE(review): the current implementation does not use `opName`, so this
+  /// pattern still matches every TilingInterface op.
+  TileUsingSCFForOp(StringRef opName, MLIRContext *context,
+                    SCFTilingOptions options, PatternBenefit benefit = 1);
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<SCFTilingResult>
+  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+
+private:
+  /// Options to control tiling.
+  SCFTilingOptions options;
+};
+
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H

diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index ebd055cad4ee8..3c75754d64125 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SCF_UTILS_UTILS_H_
 #define MLIR_DIALECT_SCF_UTILS_UTILS_H_
 
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -32,12 +33,6 @@ class CallOp;
 class FuncOp;
 } // namespace func
 
-namespace scf {
-class IfOp;
-class ForOp;
-class ParallelOp;
-} // namespace scf
-
 /// Replace the `loop` with `newIterOperands` added as new initialization
 /// values. `newYieldValuesFn` is a callback that can be used to specify
 /// the additional values to be yielded by the loop. The number of
@@ -57,6 +52,25 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                                     ValueRange newIterOperands,
                                     const NewYieldValueFn &newYieldValuesFn);
 
+/// Update a perfectly nested loop nest to yield new values from the innermost
+/// loop and propagate them up through the loop nest. This function
+/// - Expects `loopNest` to be a perfectly nested loop with the outermost loop
+///   first and the innermost loop last.
+/// - `newIterOperands` are the initialization values to be used for the
+///   outermost loop.
+/// - `newYieldValueFn` is the callback that generates the new values to be
+///   yielded from within the innermost loop.
+/// - The original loops are not erased, but are left in a "no-op" state where
+///   the body of the loop just yields the basic block arguments that correspond
+///   to the initialization values of a loop. The original loops are dead after
+///   this method.
+/// - All uses of the `newIterOperands` within the generated new loop
+///   are replaced with the corresponding `BlockArgument` in the loop body.
+/// Returns the new loops in the same order as `loopNest` (outermost first), or
+/// an empty vector when `loopNest` is empty.
+SmallVector<scf::ForOp>
+replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
+                             ValueRange newIterOperands,
+                             NewYieldValueFn newYieldValueFn);
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 6346899b39981..606901375ede8 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -98,6 +98,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         /*defaultImplementation=*/[{
           return {};
         }]
+      >,
+      // Out-parameters: on success, `resultOffsets` and `resultSizes` receive
+      // the offsets and sizes, within result `resultNumber` of the original
+      // operation, of the tile computed by the tiled implementation. The
+      // default implementation returns failure, so ops that do not override
+      // it cannot report result tile positions.
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of the result tile computed by the tiled operation.
+
+          Specifies what tile of the result of the original tensor is computed
+          by the tiled implementation. Expects the same `offsets` and `sizes` as
+          used to obtain the tiled implementation of the operation.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getResultTilePosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$resultOffsets,
+          "SmallVector<OpFoldResult> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
       >
   ];
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index fc17fba490aa1..cf771861ff580 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   SparseTensorRewriting.cpp
   SplitReduction.cpp
   Tiling.cpp
+  TilingInterfaceImpl.cpp
   Transforms.cpp
   Vectorization.cpp
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index dfc78977c560e..bb4760588bc8e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -320,8 +320,7 @@ static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
         // Compute offsets and sizes of ExtractSliceOp.
         SmallVector<Value> offsets =
             computeTileOffsets(b, loc, localIvs, tileSizes);
-        SmallVector<Value> sizes =
-            computeTileSizes(b, loc, localIvs, tileSizes, allDims);
+        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
         // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
         // Note: The tensor::PadOp is located outside of the loop nest. It is
         // later moved inside by ExtractSliceOfPadTensorSwapPattern.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
new file mode 100644
index 0000000000000..c67097ab3d695
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -0,0 +1,156 @@
+//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// External model implementation of TilingInterface for LinalgOps. An external
+/// model implementation is used for now until the use of `TilingInterface` is
+/// on par with the current Linalg tiling + fusion patterns. Once it is, it may
+/// be possible to move this into the op definition (though there are
+/// advantages to leaving it as an external model).
+template <typename LinalgOpTy>
+struct LinalgOpTilingInterface
+    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
+                                            LinalgOpTy> {
+
+  /// Return the destination operands, i.e. the output operands of the Linalg
+  /// op. No IR is created here; `b` is unused.
+  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    return llvm::cast<LinalgOp>(op).getOutputOperands();
+  }
+
+  /// Return the iterator type of each loop: the string value of each entry of
+  /// the op's `iterator_types` attribute, in order.
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
+    return llvm::to_vector(
+        llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
+          return strAttr.cast<StringAttr>().getValue();
+        }));
+  }
+
+  /// Return the iteration domain: one `Range` per loop with offset 0,
+  /// stride 1, and a size obtained by applying the op's shapes-to-loops map
+  /// to the flattened list of operand dimensions. The index constants (and any
+  /// IR needed to compute the sizes) are created at the insertion point of `b`.
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
+    AffineMap map = linalgOp.getShapesToLoopsMap();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    return llvm::to_vector(llvm::map_range(
+        applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
+          return Range{zero, v, one};
+        }));
+  }
+
+  // Instantiate the tiled implementation of the operation: slice every input
+  // and output operand according to `offsets`/`sizes` and clone the op onto
+  // the slices. `dest` and `tileDestOperands` are not used by this model.
+  SmallVector<Operation *>
+  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         bool tileDestOperands) const {
+    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
+    // specified could lead to out of bounds accesses.
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+    // The `{}` argument is the empty `sizeBounds` mentioned above.
+    // NOTE(review): the trailing `true` presumably tells makeTiledShapes to
+    // skip partial-tile bounds handling -- confirm against its declaration.
+    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+        b, loc, linalgOp, valuesToTile,
+        getValueOrCreateConstantIndexOp(b, loc, offsets),
+        getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
+
+    // The tiled op's result types are the types of its sliced output tensors.
+    SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
+        linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+          return tiledOperands[opOperand->getOperandNumber()].getType();
+        }));
+
+    Operation *tiledOp =
+        linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
+
+    return {tiledOp};
+  }
+
+  // Return the details of the output tile generated by the tiled
+  // implementation: the offsets and sizes, within output operand
+  // `resultNumber`, of the tile that the tiled op computes. Fails if the slice
+  // built for that operand is not a `tensor.extract_slice`.
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+
+    AffineExpr d0;
+    bindDims(b.getContext(), d0);
+
+    // Local helper: fully compose `expr` with `operands`, canonicalize, and
+    // build (or fold) an `affine.apply`. The lambda shadows the free function
+    // of the same name, hence the explicit `mlir::` qualification inside.
+    auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
+                                               AffineExpr expr,
+                                               ValueRange operands) -> Value {
+      AffineMap map = AffineMap::inferFromExprList({expr}).front();
+      SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
+      mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
+      canonicalizeMapAndOperands(&map, &normalizedOperands);
+      return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
+    };
+
+    SmallVector<Value> sizeVals =
+        getValueOrCreateConstantIndexOp(b, loc, sizes);
+    // Sizes are converted to `size - 1`, presumably because makeTiledShape
+    // expects closed-interval sub-shape sizes (as computeTileSizes produces).
+    SmallVector<Value> subShapeSizes =
+        llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
+          return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
+        }));
+    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+    Value sliceOpResult =
+        makeTiledShape(b, loc, outOperand->get(), sizeVals,
+                       linalgOp.getTiedIndexingMap(outOperand),
+                       getValueOrCreateConstantIndexOp(b, loc, offsets),
+                       /*ubs*/ {}, subShapeSizes, true);
+    // NOTE(review): the slice op is only inspected for its offsets and sizes
+    // and is not otherwise used here, so it stays in the IR unused; presumably
+    // later dead-code elimination removes it -- confirm.
+    auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!sliceOp)
+      return failure();
+    resultOffsets = sliceOp.getMixedOffsets();
+    resultSizes = sliceOp.getMixedSizes();
+    return success();
+  }
+};
+
+} // namespace
+
+/// Attach the Linalg `TilingInterface` external model to the op class `OpType`
+/// in context `ctx`.
+template <typename OpType>
+static void registerOne(MLIRContext *ctx) {
+  using ModelT = LinalgOpTilingInterface<OpType>;
+  OpType::template attachInterface<ModelT>(*ctx);
+}
+
+/// Variadic helper: attach the external model to every op class in `OpTypes`,
+/// in the order they are listed.
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
+  // Expanding the pack inside a braced array initializer guarantees
+  // left-to-right evaluation; the leading 0 keeps the array non-empty for an
+  // empty pack.
+  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+  int expandedCalls[] = {0, (registerOne<OpTypes>(ctx), 0)...};
+  (void)expandedCalls;
+}
+
+// Select the op-class list section of the generated `.inc` file included
+// below; the list is expanded as the template arguments of `registerAll`.
+#define GET_OP_LIST
+
+/// Attach `LinalgOpTilingInterface` external models to `linalg.generic` and to
+/// every op in the included op list, through a registry extension keyed on the
+/// Linalg dialect (presumably run when that dialect is loaded into a context).
+void mlir::linalg::registerTilingInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
+    // GenericOp is registered on its own, presumably because it is not part of
+    // the included structured-op list -- confirm.
+    registerOne<linalg::GenericOp>(ctx);
+    registerAll<
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+        >(ctx);
+  });
+}

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index bfd2c68cfa68a..bf684344387b3 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -893,7 +893,7 @@ SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
   return offsets;
 }
 
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds) {
   SmallVector<Value> sizes;
@@ -923,7 +923,7 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
   // that define tile subshapes.
   SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
   SmallVector<Value> subShapeSizes =
-      computeTileSizes(b, loc, ivs, tileSizes, sizeBounds);
+      computeTileSizes(b, loc, tileSizes, sizeBounds);
 
   assert(static_cast<int64_t>(valuesToTile.size()) ==
              linalgOp.getNumInputsAndOutputs() &&

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 8f5322dc7b9da..c876c9071269c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
   StructuralTypeConversions.cpp
+  TileUsingInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
new file mode 100644
index 0000000000000..0f71d5288932f
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -0,0 +1,249 @@
+//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the tiling using TilingInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/TileUsingInterface.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tile-using-interface"
+
+using namespace mlir;
+
+/// Set `tileSizeComputationFunction` to materialize the constants `ts` as
+/// `arith.constant` index ops each time tile sizes are requested. When `op` is
+/// nested in a `func.func`, the constants are created at the start of that
+/// function's entry block; otherwise they are created at the builder's current
+/// insertion point. The builder's insertion point is restored afterwards.
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+  assert(!tileSizeComputationFunction && "tile sizes already set");
+  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
+  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
+    OpBuilder::InsertionGuard guard(b);
+    // Hoist the constants to the top of the enclosing function when there is
+    // one; do not assume a `func.func` ancestor always exists.
+    if (auto funcOp = op->getParentOfType<func::FuncOp>())
+      b.setInsertionPointToStart(&funcOp.getBody().front());
+    return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) {
+      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
+      return v;
+    }));
+  };
+  return *this;
+}
+
+/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// - `loopRanges` describes the untiled iteration space. Each range's `offset`
+///   is the loop lower bound, and its `size` is used directly as the loop
+///   upper bound and as the bound in the tile-size clamp below. `stride` is
+///   not used.
+///   NOTE(review): using `size` as the upper bound is only correct when the
+///   offsets are zero (as for the Linalg iteration domains in this patch).
+/// - `tileSizeVals` is the tile sizes to use. A zero tile size marks an
+///   untiled loop, for which no `scf.for` is generated.
+/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
+///   the tile processed within the innermost loop.
+/// Returns the generated loops, outermost first. The builder's insertion point
+/// is restored on return.
+static SmallVector<scf::ForOp>
+generateTileLoopNest(OpBuilder &builder, Location loc,
+                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
+                     SmallVector<OpFoldResult> &offsets,
+                     SmallVector<OpFoldResult> &sizes) {
+  assert(!loopRanges.empty() && "expected at least one loop range");
+  assert(loopRanges.size() == tileSizeVals.size() &&
+         "expected as many tile sizes as loop ranges");
+  OpBuilder::InsertionGuard guard(builder);
+  SmallVector<scf::ForOp> loops;
+  offsets.resize(loopRanges.size());
+  sizes.resize(loopRanges.size());
+
+  // The tile size to use (to avoid out of bounds access) is the minimum of
+  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the
+  // tiled loop: min(s0, s1 - d0) with d0 = iv, s0 = tileSize, s1 = ub.
+  AffineExpr s0, s1, d0;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
+
+  for (auto loopRange : llvm::enumerate(loopRanges)) {
+    unsigned dim = loopRange.index();
+    const Range &range = loopRange.value();
+    Value tileSize = tileSizeVals[dim];
+
+    // No loops if tile size is zero. Set offset and size to the loop
+    // offset and size.
+    if (matchPattern(tileSize, m_Zero())) {
+      offsets[dim] = range.offset;
+      sizes[dim] = range.size;
+      continue;
+    }
+
+    auto loop = builder.create<scf::ForOp>(
+        loc, range.offset, range.size, tileSize, ValueRange{},
+        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+            ValueRange /*iterArgs*/) {
+          // Build the body with the builder and location handed to the
+          // callback rather than relying on them aliasing the outer ones.
+          Value boundedTileSize = bodyBuilder.create<AffineMinOp>(
+              bodyLoc, minMap, ValueRange{iv, tileSize, range.size});
+          sizes[dim] = boundedTileSize;
+          bodyBuilder.create<scf::YieldOp>(bodyLoc);
+        });
+    offsets[dim] = loop.getInductionVar();
+    loops.push_back(loop);
+    // Nest the next loop inside this one.
+    builder.setInsertionPoint(loop.getBody()->getTerminator());
+  }
+  return loops;
+}
+
+/// Build a tiling pattern that applies to every op implementing
+/// `TilingInterface`, configured by `options`.
+scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
+                                          scf::SCFTilingOptions options,
+                                          PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      options(std::move(options)) {}
+
+/// Build a tiling pattern meant to be restricted to `opName`.
+/// NOTE(review): `opName` is not used, so this constructor is equivalent to the
+/// one above and the pattern still matches every `TilingInterface` op.
+scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
+                                          MLIRContext *context,
+                                          scf::SCFTilingOptions options,
+                                          PatternBenefit benefit)
+    : TileUsingSCFForOp(context, std::move(options), benefit) {}
+
+/// Tile `op` with a nest of `scf.for` loops as configured by `options`:
+///   1. Query the iteration domain of `op`.
+///   2. Materialize one tile size per loop. Missing trailing sizes are padded
+///      with zero ("do not tile"); sizes beyond the number of loops are
+///      dropped.
+///   3. Build an empty loop nest over the tiled loops.
+///   4. Create the tiled implementation inside the innermost loop.
+///   5. If `op` has results, thread them through the loop nest with
+///      destructive updates (`tensor.insert_slice` + `scf.yield`) and replace
+///      `op` with the results of the outermost loop.
+/// On success returns the tiled op and the generated loops (outermost first).
+FailureOr<scf::SCFTilingResult>
+scf::TileUsingSCFForOp::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(op);
+
+  if (!options.tileSizeComputationFunction) {
+    return rewriter.notifyMatchFailure(
+        op, "missing tile size computation function");
+  }
+
+  // 1. Get the range of the loops that are represented by the operation.
+  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
+  size_t numLoops = iterationDomain.size();
+  if (numLoops == 0) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to tile op with no iteration domain");
+  }
+
+  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  SmallVector<Value, 4> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  if (tileSizeVector.size() < numLoops) {
+    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+  } else if (tileSizeVector.size() > numLoops) {
+    // Tile sizes beyond the number of loops have nothing to apply to. Drop
+    // them instead of violating the one-size-per-loop precondition of
+    // `generateTileLoopNest`.
+    tileSizeVector.resize(numLoops);
+  }
+
+  // The destination operands become the `init` values of the loop nest built
+  // in step 5, so materialize them here, before (and dominating) the loops,
+  // and reuse the same values for the tiled implementation.
+  SmallVector<Value> destinationOperands = op.getDestinationOperands(rewriter);
+
+  scf::SCFTilingResult tilingResult;
+  SmallVector<OpFoldResult> offsets, sizes;
+  {
+    // 3. Materialize an empty loop nest that iterates over the tiles. These
+    // loops for now do not return any values even if the original operation has
+    // results.
+    tilingResult.loops = generateTileLoopNest(
+        rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+
+    LLVM_DEBUG({
+      if (!tilingResult.loops.empty()) {
+        llvm::errs() << "LoopNest shell :\n";
+        tilingResult.loops.front().dump();
+        llvm::errs() << "\n";
+      }
+    });
+
+    // 4. Generate the tiled implementation within the inner most loop.
+    if (!tilingResult.loops.empty())
+      rewriter.setInsertionPoint(
+          tilingResult.loops.back().getBody()->getTerminator());
+    SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
+        rewriter, destinationOperands, offsets, sizes, true);
+    if (tiledImplementation.size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected tiled implementation to return a single op");
+    }
+    tilingResult.tiledOp = tiledImplementation[0];
+
+    LLVM_DEBUG({
+      if (!tilingResult.loops.empty()) {
+        llvm::errs() << "After tiled implementation :\n";
+        tilingResult.loops.front().dump();
+        llvm::errs() << "\n";
+      }
+    });
+  }
+
+  if (op->getNumResults() == 0) {
+    rewriter.eraseOp(op);
+    return tilingResult;
+  }
+
+  // 5. If the original operations has results, modify the loop nest to yield
+  // the replacement values.
+  if (tilingResult.loops.empty()) {
+    // 5a. If there were no loops, the tiled implementation results are the
+    // replacements.
+    rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
+    return tilingResult;
+  }
+
+  // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
+  // replacement values using destructive updates. Use the `TilingInterface`
+  // to get the position of the result tiles and use that to generate the
+  // destructive update pattern, i.e.,
+  //
+  // ```mlir
+  // scf.for %iv0 = ... {
+  //   %0 = tiled_op
+  // }
+  // ```
+  //
+  // is transformed to
+  //
+  // ```mlir
+  // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
+  //   %0 = tiled_op
+  //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
+  //   scf.yield %1
+  // }
+  // ```
+  //
+  // Query the position of every result tile before the loop nest is rewritten,
+  // inside the innermost loop right before its terminator (after the tiled
+  // op). This lets a failure be reported cleanly. Bailing out from inside the
+  // yield callback would instead hand `replaceLoopNestWithNewYields` fewer
+  // yielded values than new iter operands.
+  rewriter.setInsertionPoint(
+      tilingResult.loops.back().getBody()->getTerminator());
+  unsigned numResults = op->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultTileOffsets(numResults);
+  SmallVector<SmallVector<OpFoldResult>> resultTileSizes(numResults);
+  for (unsigned resultNum : llvm::seq<unsigned>(0, numResults)) {
+    if (failed(op.getResultTilePosition(rewriter, resultNum, offsets, sizes,
+                                        resultTileOffsets[resultNum],
+                                        resultTileSizes[resultNum]))) {
+      op.emitOpError("unable to get position of result ")
+          << resultNum << " of the tiled implementation";
+      return failure();
+    }
+  }
+
+  NewYieldValueFn yieldValueFn =
+      [&](OpBuilder &b, Location loc,
+          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
+    SmallVector<Value> yieldedValues;
+    Attribute one = b.getIndexAttr(1);
+    for (unsigned resultNum : llvm::seq<unsigned>(0, numResults)) {
+      SmallVector<OpFoldResult> resultTileStrides(
+          resultTileOffsets[resultNum].size(), one);
+      Value yieldedValue = b.create<tensor::InsertSliceOp>(
+          op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
+          newBBArgs[resultNum], resultTileOffsets[resultNum],
+          resultTileSizes[resultNum], resultTileStrides);
+      yieldedValues.push_back(yieldedValue);
+    }
+    return yieldedValues;
+  };
+  SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
+      rewriter, tilingResult.loops, destinationOperands, yieldValueFn);
+  // The original loops are left as dead no-op shells; erase them and record
+  // the new loops instead.
+  for (auto loop : llvm::enumerate(tilingResult.loops)) {
+    rewriter.eraseOp(loop.value());
+    tilingResult.loops[loop.index()] = newLoops[loop.index()];
+  }
+  rewriter.replaceOp(op, tilingResult.loops.front().getResults());
+  return tilingResult;
+}

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index bce73bd3c432d..2ffe2e955d1bc 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 
@@ -101,6 +102,31 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
   return newLoop;
 }
 
+/// Rebuilds every loop of `loopNest` so that it carries `newIterOperands` as
+/// additional iter_args. `loopNest` is ordered outermost first; loopNest[i + 1]
+/// is assumed to be nested inside loopNest[i], because the rebuilt loop i
+/// yields values produced by the rebuilt loop i + 1.
+/// The innermost loop yields the values returned by `newYieldValueFn`. Every
+/// enclosing loop yields the trailing `newIterOperands.size()` results of the
+/// rebuilt loop directly inside it.
+/// Returns the rebuilt loops in the same order as `loopNest`, or an empty
+/// vector if `loopNest` is empty. The original loops are not erased here;
+/// callers are expected to erase them, as the tiling driver above does.
+SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
+    OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
+    ValueRange newIterOperands, NewYieldValueFn newYieldValueFn) {
+  if (loopNest.empty())
+    return {};
+  SmallVector<scf::ForOp> newLoopNest(loopNest.size());
+
+  // Rebuild the innermost loop first. It is the only loop whose new yielded
+  // values come from the caller-provided callback.
+  newLoopNest.back() = replaceLoopWithNewYields(
+      builder, loopNest.back(), newIterOperands, newYieldValueFn);
+
+  // Walk outward. Each enclosing loop yields the new results of the loop
+  // rebuilt in the previous iteration, so that loop must already exist.
+  for (unsigned loopDepth :
+       llvm::reverse(llvm::seq<unsigned>(0, loopNest.size() - 1))) {
+    // The builder, location and new block arguments are unused. The
+    // by-reference capture of `loopDepth` assumes replaceLoopWithNewYields
+    // invokes the callback before it returns (its body is not shown here).
+    NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc,
+                             ArrayRef<BlockArgument> innerNewBBArgs) {
+      SmallVector<Value> newYields(
+          newLoopNest[loopDepth + 1]->getResults().take_back(
+              newIterOperands.size()));
+      return newYields;
+    };
+    // NOTE(review): every rebuilt loop, including the inner ones, gets
+    // `newIterOperands` as its initial values. The expected IR in
+    // tile-using-interface.mlir initializes inner iter_args from the
+    // enclosing loop's iter_args. This presumably relies on
+    // replaceLoopWithNewYields rewiring uses of `newIterOperands` inside the
+    // new loop body — confirm.
+    newLoopNest[loopDepth] = replaceLoopWithNewYields(
+        builder, loopNest[loopDepth], newIterOperands, fn);
+  }
+  return newLoopNest;
+}
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
new file mode 100644
index 0000000000000..1e094329db66f
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s
+
+// Tiles the M and N dimensions of a tensor `linalg.matmul` by 10 and 20; K is
+// left untiled. Each tile's result is written back with `tensor.insert_slice`
+// into the loop-carried tensor and yielded out of both loops.
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//      CHECK: func.func @simple_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
+//      CHECK:     %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//      CHECK:       %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:           [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:           [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT1]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:           outs(%[[INIT_TILE]] :
+//      CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT1]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:       scf.yield %[[UPDATE]]
+//      CHECK:     scf.yield %[[INNER]]
+//      CHECK:   return %[[OUTER]]
+
+// -----
+
+// Tiles M, N and K of a `linalg.matmul` on memrefs by 10, 20 and 30. The
+// tiled matmul operates on `memref.subview`s of the operands, and the
+// function returns no values.
+func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+    %arg2 : memref<?x?xf32>) {
+  linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"}
+      ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//      CHECK: func.func @simple_matmul_memref(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
+//  CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+//      CHECK:     %[[TS_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+//      CHECK:       %[[TS_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
+//      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
+//      CHECK:         %[[TS_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[K]]]
+//  CHECK-DAG:         %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
+// CHECK-SAME:             [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
+//  CHECK-DAG:         %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME:             [%[[IV2]], %[[IV1]]] [%[[TS_K]], %[[TS_N]]] [1, 1]
+//  CHECK-DAG:         %[[OUT_TILE:.+]] = memref.subview %[[ARG2]]
+// CHECK-SAME:             [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:             ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:             outs(%[[OUT_TILE]] :
+
+// -----
+
+// A fully parallel 3-D `linalg.generic` with two results, each a differently
+// transposed copy of the input. Dimensions 0 and 2 are tiled by 10 and 20;
+// tile size 0 leaves dimension 1 untiled. Both results are threaded through
+// the two-deep loop nest as iter_args.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+  %init0 = linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32>
+  %init1 = linalg.init_tensor [300, 128, 200] : tensor<300x128x200xf32>
+  %0:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      {__internal_linalg_transform__ = "parallel_generic_transpose"}
+      ins(%arg0 : tensor<128x200x300xf32>)
+      outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      linalg.yield %b0, %b0 : f32, f32
+    } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
+  return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//      CHECK: func.func @multi_result(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//  CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200]
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]]
+//      CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
+//      CHECK:       %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]]
+//  CHECK-DAG:       %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:           [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1]
+//  CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+//  CHECK-DAG:       %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
+// CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+//      CHECK:       %[[RESULT_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:           ins(%[[ARG_TILE]] :
+// CHECK-SAME:           outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//      CHECK:       %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+//      CHECK:       %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
+// CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+//      CHECK:       scf.yield %[[UPDATE0]], %[[UPDATE1]]
+//      CHECK:     scf.yield %[[INNER]]#0, %[[INNER]]#1
+//      CHECK:   return %[[OUTER]]#0, %[[OUTER]]#1
+
+// -----
+
+// Tiles the two filter-window loops (filter dims P and Q) and the
+// input-channel loop (C) of a `linalg.conv_2d_nhwc_hwcf` by 10, 20 and 30.
+// The remaining loops are left untiled, so the output tile is the whole
+// output.
+//
+// NOTE(review): the op is given `dilation` (singular). The expected maps
+// below only reflect the strides. #MAP3 expands to (R - 1) * 2 + TS_P, which
+// is stride 2 with unit dilation, and the input tile starts at IV0 rather
+// than at a multiple of 4. This suggests the named op's dilation attribute
+// is spelled differently (likely `dilations` — confirm) and that `dilation`
+// is kept only as a discardable attribute. If so, dilations 4 and 5 are not
+// actually exercised by this test.
+func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+    %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf {
+      strides = dense<[2, 3]> : tensor<2xi64>,
+      dilation = dense<[4, 5]> : tensor<2xi64>,
+      __internal_linalg_transform__ = "simple_conv"}
+      ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
+//      CHECK: func.func @conv2D(
+// CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
+//  CHECK-DAG:   %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
+//  CHECK-DAG:   %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
+//  CHECK-DAG:   %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
+//  CHECK-DAG:   %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
+//  CHECK-DAG:   %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
+//  CHECK-DAG:   %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[INIT]])
+//      CHECK:     %[[TS_P:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[P]]]
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//      CHECK:       %[[TS_Q:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[Q]]]
+//      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
+// CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
+//  CHECK-DAG:         %[[TS_C:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[C]]]
+//  CHECK-DAG:         %[[TS_H:.+]] = affine.apply #[[MAP3]](%[[TS_P]])[%[[R]]]
+//  CHECK-DAG:         %[[TS_W:.+]] = affine.apply #[[MAP4]](%[[TS_Q]])[%[[S]]]
+//  CHECK-DAG:         %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME:             [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]]
+//  CHECK-DAG:         %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME:             [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]]
+//  CHECK-DAG:         %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]]
+// CHECK-SAME:             [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+//      CHECK:         %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:             dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME:             ins(%[[INPUT_TILE]], %[[FILTER_TILE]] :
+// CHECK-SAME:             outs(%[[INIT_TILE]] :
+//      CHECK:         tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
+// CHECK-SAME:             [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]

diff  --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt
index 97149dcf38d8e..88e55e77a3fb9 100644
--- a/mlir/test/lib/CMakeLists.txt
+++ b/mlir/test/lib/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(Analysis)
 add_subdirectory(Conversion)
 add_subdirectory(Dialect)
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(Reducer)

diff  --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt
new file mode 100644
index 0000000000000..4a0567ab46423
--- /dev/null
+++ b/mlir/test/lib/Interfaces/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(TilingInterface)

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
new file mode 100644
index 0000000000000..437e39c30b7be
--- /dev/null
+++ b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Test-only library providing the `-test-tiling-interface` pass
+# (TestTilingInterface.cpp). It is excluded from libMLIR.
+# NOTE(review): TestTilingInterface.cpp also uses the Func dialect, the pass
+# infrastructure, the TilingInterface library and the greedy rewrite driver.
+# The Bazel target lists these explicitly but this list does not; they are
+# presumably pulled in transitively — confirm.
+add_mlir_library(MLIRTilingInterfaceTestPasses
+  TestTilingInterface.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRAffine
+  MLIRArithmetic
+  MLIRLinalg
+  MLIRLinalgTransforms
+  MLIRMemRef
+  MLIRSCF
+  MLIRSCFTransforms
+  MLIRTensor
+  )

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
new file mode 100644
index 0000000000000..c3795c32f32bb
--- /dev/null
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -0,0 +1,126 @@
+//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing tiling operations using
+// `TilingInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pattern that tiles `TilingInterface` ops into `scf.for` loops. It only
+/// applies to ops accepted by a `linalg::LinalgTransformationFilter`, and
+/// after a successful tiling it sets the filter's replacement marker on the
+/// tiled op.
+struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
+  /// Construct a generic pattern applied to all TilingInterface ops that
+  /// verify `filter`.
+  TestTileUsingSCFForOpWithFilter(MLIRContext *context,
+                                  scf::SCFTilingOptions options,
+                                  linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter(),
+                                  PatternBenefit benefit = 1)
+      : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
+
+  /// Construct a pattern applied only to ops named `opName` that also verify
+  /// `filter`.
+  TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
+                                  scf::SCFTilingOptions options,
+                                  linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter(),
+                                  PatternBenefit benefit = 1)
+      : scf::TileUsingSCFForOp(context, options, benefit), filter(filter),
+        rootOpName(OperationName(opName, context)) {}
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    // Enforce the op-name restriction requested through the `opName`
+    // constructor. Previously `opName` was accepted but silently dropped.
+    if (rootOpName && op->getName() != *rootOpName)
+      return failure();
+    if (failed(filter.checkAndNotify(rewriter, op)))
+      return failure();
+
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        returningMatchAndRewrite(op, rewriter);
+    if (failed(tilingResult))
+      return failure();
+    // Retag the tiled op with the filter's replacement marker, presumably so
+    // the filter stops matching it and it is not tiled again.
+    filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
+    return success();
+  }
+
+private:
+  linalg::LinalgTransformationFilter filter;
+  /// When set, only ops with this name are tiled.
+  Optional<OperationName> rootOpName;
+};
+
+/// Pass that runs the `TilingInterface`-based tiling test patterns on each
+/// function. It is registered as `-test-tiling-interface`.
+struct TestTilingInterfacePass
+    : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
+
+  TestTilingInterfacePass() = default;
+  TestTilingInterfacePass(const TestTilingInterfacePass &other)
+      : PassWrapper(other) {}
+
+  StringRef getArgument() const final { return "test-tiling-interface"; }
+
+  StringRef getDescription() const final {
+    return "Test tiling using TilingInterface";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Dialects whose ops the tiled IR contains (affine.min, memref.subview,
+    // scf.for, tensor.extract_slice / tensor.insert_slice).
+    registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
+                    tensor::TensorDialect>();
+    // Attach the Linalg external models of `TilingInterface`.
+    linalg::registerTilingInterfaceExternalModels(registry);
+  }
+
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Populates `patterns` with one filtered tiling pattern per test case. Each
+/// pattern only fires on ops carrying the matching marker and re-marks the
+/// tiled op as "tiled". A tile size of 0 leaves the corresponding loop
+/// untiled.
+static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
+  StringAttr tiledMarker = StringAttr::get(context, "tiled");
+  auto addFilteredTiling = [&](ArrayRef<int64_t> sizes, StringRef marker) {
+    scf::SCFTilingOptions options;
+    options.setTileSizes(sizes);
+    linalg::LinalgTransformationFilter filter(StringAttr::get(context, marker),
+                                              tiledMarker);
+    patterns.add<TestTileUsingSCFForOpWithFilter>(context, options, filter);
+  };
+
+  // `linalg.matmul` on tensors: tile M and N only.
+  addFilteredTiling({10, 20}, "simple_gemm");
+  // `linalg.matmul` on memrefs: tile M, N and K.
+  addFilteredTiling({10, 20, 30}, "simple_gemm_memref");
+  // Transposing 3-D parallel `linalg.generic`: tile dimensions 0 and 2.
+  addFilteredTiling({10, 0, 20}, "parallel_generic_transpose");
+  // 2-D convolution: tile only the last three loops.
+  addFilteredTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
+}
+
+/// Greedily applies the tiling test patterns to the current function. The
+/// pass fails if the rewrite driver reports failure.
+void TestTilingInterfacePass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  addTestPatterns(ctx, patterns);
+
+  LogicalResult rewriteStatus =
+      applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  if (failed(rewriteStatus))
+    signalPassFailure();
+}
+
+namespace mlir {
+namespace test {
+/// Makes `-test-tiling-interface` available to mlir-opt.
+void registerTestTilingInterface() {
+  // Constructing a PassRegistration object registers the pass.
+  PassRegistration<TestTilingInterfacePass> registration;
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index a8172b83f1a47..97b082e83e5da 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -33,6 +33,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestRewrite
     MLIRTestTransformDialect
     MLIRTestTransforms
+    MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
     )
 endif()

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index aa94294b4ea8d..b50cfa964290f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
+void registerTestTilingInterface();
 void registerTestTransformDialectInterpreterPass();
 void registerTestVectorLowerings();
 } // namespace test
@@ -206,6 +207,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorTransforms();
+  mlir::test::registerTestTilingInterface();
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
 }

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5fde5e771e538..49c08f7c01535 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1864,6 +1864,7 @@ cc_library(
         "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
         "include/mlir/Dialect/SCF/Passes.h",
         "include/mlir/Dialect/SCF/Patterns.h",
+        "include/mlir/Dialect/SCF/TileUsingInterface.h",
         "include/mlir/Dialect/SCF/Transforms.h",
     ],
     includes = ["include"],
@@ -1883,6 +1884,7 @@ cc_library(
         ":SCFUtils",
         ":Support",
         ":TensorDialect",
+        ":TilingInterface",
         ":Transforms",
         "//llvm:Support",
     ],
@@ -2645,6 +2647,7 @@ cc_library(
         exclude = [
             "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
             "include/mlir/Dialect/SCF/Patterns.h",
+            "include/mlir/Dialect/SCF/TileUsingInterface.h",
             "include/mlir/Dialect/SCF/Transforms.h",
         ],
     ),
@@ -6313,6 +6316,7 @@ cc_binary(
         "//mlir/test:TestSPIRV",
         "//mlir/test:TestShapeDialect",
         "//mlir/test:TestTensor",
+        "//mlir/test:TestTilingInterface",
         "//mlir/test:TestTosaDialect",
         "//mlir/test:TestTransformDialect",
         "//mlir/test:TestTransforms",
@@ -7492,6 +7496,7 @@ cc_library(
         ":TensorTilingInterfaceImpl",
         ":TensorTransforms",
         ":TensorUtils",
+        ":TilingInterface",
         ":TransformUtils",
         ":Transforms",
         ":VectorDialect",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 742e7b610453a..fa89b9c990c0b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -293,6 +293,28 @@ cc_library(
     ],
 )
 
+# Test library providing the `-test-tiling-interface` pass that is linked into
+# mlir-opt.
+cc_library(
+    name = "TestTilingInterface",
+    srcs = glob(["lib/Interfaces/TilingInterface/*.cpp"]),
+    # NOTE(review): the sources here do not appear to include Test dialect
+    # headers, so this include path may be unnecessary — confirm.
+    includes = ["lib/Dialect/Test"],
+    deps = [
+        "//llvm:Support",
+        "//mlir:Affine",
+        "//mlir:ArithmeticDialect",
+        "//mlir:FuncDialect",
+        "//mlir:IR",
+        "//mlir:LinalgDialect",
+        "//mlir:LinalgTransforms",
+        "//mlir:MemRefDialect",
+        "//mlir:Pass",
+        "//mlir:SCFDialect",
+        "//mlir:SCFTransforms",
+        "//mlir:TensorDialect",
+        "//mlir:TilingInterface",
+        "//mlir:Transforms",
+    ],
+)
+
 cc_library(
     name = "TestPass",
     srcs = glob(["lib/Pass/*.cpp"]),


        


More information about the Mlir-commits mailing list