[Mlir-commits] [mlir] 1a829d2 - [mlir] Purge linalg.tiled_loop.

Alexander Belyaev llvmlistbot at llvm.org
Mon Feb 28 00:06:50 PST 2022


Author: Alexander Belyaev
Date: 2022-02-28T09:05:18+01:00
New Revision: 1a829d2d06958abf09bb6aff81120959206887f6

URL: https://github.com/llvm/llvm-project/commit/1a829d2d06958abf09bb6aff81120959206887f6
DIFF: https://github.com/llvm/llvm-project/commit/1a829d2d06958abf09bb6aff81120959206887f6.diff

LOG: [mlir] Purge linalg.tiled_loop.

Differential Revision: https://reviews.llvm.org/D119415

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir
    mlir/test/Dialect/Linalg/tile-tensors.mlir
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
    mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
    mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir
    mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 518a2cfacf2d5..0f896df15119a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -138,290 +138,6 @@ def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
   let hasVerifier = 1;
 }
 
-def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
-     AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<LoopLikeOpInterface>,
-     RecursiveSideEffects,
-     SingleBlockImplicitTerminator<"linalg::YieldOp">
-    ]> {
-  let summary = "Linalg tiled loop operation";
-  let description = [{
-    This is a loop-like operation with additional properties. The arguments
-    also include the input and the output tensors or memrefs and the attributes
-    to specify the iterator types.
-
-    Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
-    to "parallel" type, when it is absent from the custom format.
-
-    Tensor-based version:
-
-    The body region of the loop contains `extract_slice` operations applied to
-    every tensor argument of TiledLoopOp.
-
-    The body region must contain exactly one block that terminates with
-    `linalg.yield` with the operands resulting from `insert_slice` operations.
-
-    Example:
-
-    ```mlir
-    %0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
-        ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
-        outs(%out : tensor<24x64xi8>)
-        iterators("parallel")
-        distribution("block_x") {
-      %lhs_sub = tensor.extract_slice %lhs[%i, 0] [%c4, %c64] [1, 1]
-          : tensor<24x64xi8> to tensor<?x?xi8>
-      %rhs_sub = tensor.extract_slice %rhs[%i, 0] [%c4, %c64] [1, 1]
-          : tensor<24x64xi8> to tensor<?x?xi8>
-      %out_sub = tensor.extract_slice %out[%i, 0] [%c4, %c64] [1, 1]
-          : tensor<24x64xi8> to tensor<?x?xi8>
-
-      %result_sub = linalg.generic ...
-
-      %result = tensor.insert_slice %result_sub into %out[%i, 0][%c4, %c64][1, 1]
-        : tensor<?x?xi8> into tensor<24x64xi8>
-      linalg.yield %result : tensor<24x64xi8>
-    }
-    ```
-
-    MemRef-based version:
-
-    The body region of the loop contains `subview` operations applied to
-    every memref argument of TiledLoopOp.
-
-    The body region must contain exactly one block that terminates with
-    `linalg.yield` with no operands.
-
-    Example:
-
-    ```mlir
-    linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
-        ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
-        outs(%out : memref<24x64xi8>)
-        iterators("parallel")
-        distribution("block_x") {
-      %lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
-          : memref<24x64xi8> to memref<?x?xi8>
-      %rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
-          : memref<24x64xi8> to memref<?x?xi8>
-      %out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1]
-          : memref<24x64xi8> to memref<?x?xi8>
-
-      %result_sub = linalg.generic ...
-      linalg.yield
-    }
-    ```
-  }];
-
-  let arguments = (ins Variadic<Index>:$lowerBound,
-                       Variadic<Index>:$upperBound,
-                       Variadic<Index>:$step,
-                       Variadic<AnyType>:$inputs,
-                       Variadic<AnyShaped>:$outputs,
-                       ArrayAttr:$iterator_types,
-                       OptionalAttr<ArrayAttr>:$distribution_types);
-  let results = (outs Variadic<AnyRankedTensor>:$results);
-  let regions = (region SizedRegion<1>:$region);
-
-  let builders = [
-    OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
-      "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
-      "ArrayAttr":$iteratorTypes, "Optional<ArrayAttr>":$distributionTypes,
-      CArg<"function_ref<void (OpBuilder &, Location, /*ivs=*/ValueRange,"
-        "/*inputs=*/ValueRange, /*outputs=*/ValueRange)>",
-        "nullptr">:$bodyBuilderFn)>,
-    OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
-      "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
-      "ArrayAttr":$iteratorTypes,
-      CArg<"function_ref<void (OpBuilder &, Location, /*ivs=*/ValueRange,"
-        "/*inputs=*/ValueRange, /*outputs=*/ValueRange)>",
-        "nullptr">:$bodyBuilderFn)>,
-  ];
-
-  let extraClassDeclaration = [{
-    /// Number of loops
-    unsigned getNumLoops() { return step().size(); }
-
-    /// Number of input operands
-    unsigned getNumInputs() { return inputs().size(); }
-
-    /// Number of output operands
-    unsigned getNumOutputs() { return outputs().size(); }
-
-    /// Number of operands controlling the loop: lbs, ubs, steps
-    unsigned getNumControlOperands() { return 3 * getNumLoops(); }
-
-    ValueRange getInductionVars() {
-      return getBody()->getArguments().take_front(getNumLoops());
-    }
-    ValueRange getRegionInputArgs() {
-      return getBody()->getArguments().slice(getNumLoops(), inputs().size());
-    }
-    ValueRange getRegionOutputArgs() {
-      return getBody()->getArguments().take_back(outputs().size());
-    }
-
-    void setDistributionTypes(Builder& b, ArrayRef<StringRef> types) {
-      assert(types.size() == getNumLoops() &&
-             "expected distribution type for every dimension");
-      distribution_typesAttr(b.getStrArrayAttr(types));
-    }
-
-    void setLowerBounds(ValueRange lowerBounds) {
-      unsigned numLoops = getNumLoops();
-      assert(lowerBounds.size() == numLoops &&
-             "expected lower bounds for every loop dimension");
-      for (unsigned i = 0; i < numLoops; ++i)
-        setOperand(i, lowerBounds[i]);
-    }
-
-    void setUpperBounds(ValueRange upperBounds) {
-      unsigned numLoops = getNumLoops();
-      assert(upperBounds.size() == numLoops &&
-             "expected upper bounds for every loop dimension");
-      for (unsigned i = 0, pos = numLoops; i < numLoops; ++i, ++pos)
-        setOperand(pos, upperBounds[i]);
-    }
-
-    void setSteps(ValueRange steps) {
-      unsigned numLoops = getNumLoops();
-      assert(steps.size() == numLoops &&
-             "expected upper bounds for every loop dimension");
-      for (unsigned i = 0, pos = 2 * numLoops; i < numLoops; ++i, ++pos)
-        setOperand(pos, steps[i]);
-    }
-
-    /// Operand that corresponds to the `bbArg` block argument.
-    OpOperand& getTiedOperand(BlockArgument& bbArg) {
-      return getOperation()->getOpOperand(getNumControlOperands() +
-                                          bbArg.getArgNumber() - getNumLoops());
-    }
-
-    /// Block argument that corresponds to the `input` or `output` operand.
-    BlockArgument getTiedBlockArgument(OpOperand& operand) {
-      auto operandIndex = operand.getOperandNumber();
-      assert(
-          operandIndex >= getNumControlOperands() &&
-          operandIndex < getNumOperands() &&
-          "tied block arg is defined only for `input` and `output` arguments");
-      return getBody()->getArgument(operandIndex - 2 * getNumLoops());
-    }
-
-   /// Result that corresponds to the `outputs` argument of tensor type.
-   OpResult getTiedOpResult(OpOperand& opOperand) {
-      // No result can correspond to a memref argument.
-      if (opOperand.get().getType().isa<MemRefType>()) return OpResult();
-
-      // Check whether the operand index is in bounds of `outputs()` arg.
-      int operandIndex = opOperand.getOperandNumber();
-      int outputIndexStart =
-          getNumControlOperands() + inputs().size();
-      int outputIndexEnd = outputIndexStart + outputs().size();
-      if (operandIndex < outputIndexStart || operandIndex >= outputIndexEnd)
-        return OpResult();
-
-      // Count tensor arguments in `outputs` to compute the result index.
-      int tensorId = -1;
-      for (int i = outputIndexStart; i <= operandIndex; ++i)
-        tensorId += getOperand(i).getType().isa<RankedTensorType>();
-      return getOperation()->getResult(tensorId);
-    }
-
-    /// Append `operand` to the `input` arguments.
-    OpOperand& appendInputOperand(OpBuilder& builder, Value operand) {
-      int numLoops = getNumLoops();
-      int numInputs = getNumInputs();
-      int numOutputs = getNumOutputs();
-
-      getOperation()->insertOperands(getNumControlOperands() + numInputs,
-                                     operand);
-      getBody()->insertArgument(numLoops + numInputs, operand.getType(), 
-                                getLoc());
-      getOperation()->setAttr(
-          TiledLoopOp::getOperandSegmentSizeAttr(),
-          builder.getI32VectorAttr(
-              {numLoops, numLoops, numLoops, numInputs + 1, numOutputs}));
-      return getOperation()->getOpOperand(getNumControlOperands() + numInputs);
-    }
-
-    /// Append `operand` to the `output` arguments.
-    OpOperand& appendOutputOperand(OpBuilder& builder, Value operand) {
-      int numLoops = getNumLoops();
-      int numInputs = getNumInputs();
-      int numOutputs = getNumOutputs();
-
-      getOperation()->insertOperands(
-          getNumControlOperands() + numInputs + numOutputs, operand);
-      getBody()->insertArgument(numLoops + numInputs + numOutputs,
-                                operand.getType(), getLoc());
-      getOperation()->setAttr(
-          TiledLoopOp::getOperandSegmentSizeAttr(),
-          builder.getI32VectorAttr(
-              {numLoops, numLoops, numLoops, numInputs, numOutputs + 1}));
-      return getOperation()->getOpOperand(getNumControlOperands() + numInputs +
-                                          numOutputs);
-    }
-
-    /// Erase `operand` from the `input` or `output` arguments.
-    void eraseOperand(OpBuilder& builder, OpOperand& operand) {
-      int numInputs = getNumInputs();
-      int numLoops = getNumLoops();
-      int numOutputs = getNumOutputs();
-      int numControlOperands = getNumControlOperands();
-
-      int operandIndex = operand.getOperandNumber();
-      assert(operandIndex >= numControlOperands &&
-             operandIndex < static_cast<int>(getNumOperands()) &&
-             "Can erase only `input` or `output` operand");
-
-      if (operandIndex >= numControlOperands + numInputs)
-        --numOutputs;
-      else
-        --numInputs;
-
-      getOperation()->eraseOperand(operandIndex);
-      getBody()->eraseArgument(operandIndex - 2 * numLoops);
-      getOperation()->setAttr(
-          TiledLoopOp::getOperandSegmentSizeAttr(),
-          builder.getI32VectorAttr(
-              {numLoops, numLoops, numLoops, numInputs, numOutputs}));
-    }
-
-    OpOperand* findInputOperand(Value value) {
-      OperandRange::iterator it = llvm::find(inputs(), value);
-      if (it == inputs().end()) return nullptr;
-      return it.getBase();
-    }
-
-    OpOperand* findOutputOperand(Value value) {
-      OperandRange::iterator it = llvm::find(outputs(), value);
-      if (it == outputs().end()) return nullptr;
-      return it.getBase();
-    }
-
-    /// Return whether the op has only MemRef input and outputs.
-    bool hasBufferSemantics() {
-      Operation* op = this->getOperation();
-      return op->getNumResults() == 0 &&
-             llvm::all_of(op->getOpOperands(), [&](OpOperand & operand) {
-               return !operand.get().getType().template isa<ShapedType>() ||
-                      operand.get().getType().template isa<MemRefType>();
-             });
-    }
-
-    /// Return whether the loop dimension is parallel or not.
-    bool isParallelDimension(unsigned dim) {
-      StringAttr attr = this->iterator_types()[dim].cast<StringAttr>();
-      return attr.getValue() == getParallelIteratorTypeName();
-    }
-  }];
-
-  let hasCanonicalizer = 1;
-  let hasCustomAssemblyFormat = 1;
-  let hasFolder = 1;
-  let hasVerifier = 1;
-}
-
 def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>,
     Arguments<(ins Confined<I64Attr, [IntMinValue<0>]>:$dim)>,
     Results<(outs Index:$result)> {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 487362c62e60a..3f8719b0782b5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -31,10 +31,10 @@ std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
 std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
 
-std::unique_ptr<OperationPass<FuncOp>> createLinalgTilingPass(
-    ArrayRef<int64_t> tileSizes = {},
-    linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops,
-    ArrayRef<StringRef> distributionTypes = {});
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {},
+                       linalg::LinalgTilingLoopType loopType =
+                           linalg::LinalgTilingLoopType::Loops);
 
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
@@ -42,10 +42,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createLinalgInlineScalarOperandsPass();
 
-/// Create a pass to convert Linalg tiled loops to `scf.for` and `scf.parallel`
-/// loops and memref.load/memref.store accesses.
-std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTiledLoopsToSCFPass();
-
 /// Create a pass to convert Linalg operations to scf.for loops and
 /// memref.load/memref.store accesses.
 std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToLoopsPass();

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index dc14011c8fd13..22989386f6b95 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -121,17 +121,6 @@ def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
   let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
 }
 
-def LinalgLowerTiledLoopsToSCF
-    : Pass<"convert-linalg-tiled-loops-to-scf", "FuncOp"> {
-  let summary = "Lower linalg tiled loops to SCF loops and parallel loops";
-  let constructor = "mlir::createConvertLinalgTiledLoopsToSCFPass()";
-  let dependentDialects = [
-    "linalg::LinalgDialect",
-    "scf::SCFDialect",
-    "AffineDialect"
-  ];
-}
-
 def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands", "FuncOp"> {
   let summary = "Inline scalar operands into linalg generic ops";
   let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
@@ -207,12 +196,7 @@ def LinalgTiling : Pass<"linalg-tile", "FuncOp"> {
     ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
                "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
     Option<"loopType", "loop-type", "std::string", /*default=*/"\"for\"",
-           "Specify the type of loops to generate: for, parallel or "
-           "tiled_loop">,
-    ListOption<"distributionTypes", "distribution-types", "std::string",
-               "DistributionTypes (if loop-type=tiled_loop)",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
-
+           "Specify the type of loops to generate: for, parallel">
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 50e6191db5e8a..4c0146056aa72 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -131,9 +131,6 @@ void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 /// Patterns that are used to inline constant operands into linalg generic ops.
 void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 
-/// Pattern to convert TiledLoopOp to SCF loops.
-void populateTiledLoopToSCFPattern(RewritePatternSet &patterns);
-
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
   /// Enable fusion of reshapes into the shape with elementwise operations. By
@@ -1248,13 +1245,6 @@ void populateDecomposeConvolutionPatterns(
     const LinalgTransformationFilter &filter = LinalgTransformationFilter(),
     PatternBenefit benefit = 1);
 
-/// Linalg distribution patterns
-//
-/// Populates `patterns` with patterns to distribute linalg.tiled_loop.
-void populateLinalgDistributeTiledLoopPattern(
-    RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
-    const LinalgTransformationFilter &marker);
-
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
@@ -1368,31 +1358,6 @@ struct LinalgCopyVTWForwardingPattern
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly
-/// into a TiledLoopOp where the step divides the iteration space evenly,
-/// followed by another TiledLoopOp for the last (partial) iteration (if any).
-/// This transformation is called "loop peeling".
-///
-/// This function peels the `idx`-th loop of the TiledLoopOp. To tile all loops
-/// in the loop nest, this function must be called multiple times.
-///
-/// After loop peeling, this function tries to simplify/canonicalize affine.min
-/// and affine.max ops in the body of the two TiledLoopOps. For more details,
-/// refer to `mlir::scf::peelAndCanonicalizeForLoop`.
-///
-/// The return value indicates whether the loop was rewritten or not. Loops are
-/// not rewritten if:
-/// * Loop step size is 1 or
-/// * Loop bounds and step size are static, and step already divides the
-///   iteration space evenly.
-///
-/// Note: This function rewrites the given TiledLoopOp in-place and clones the
-/// TileLoopOp operation for the last iteration. It replaces all uses of the
-/// unpeeled TiledLoopOp with the results of the newly generated TiledLoopOp.
-LogicalResult peelAndCanonicalizeTiledLoop(RewriterBase &rewriter,
-                                           TiledLoopOp loopOp, int64_t idx,
-                                           TiledLoopOp &result);
-
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 44179ebe60757..0a556b3d99eb6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -105,44 +105,6 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
-/// This is a specialization of `foldMemRefCast` used for patterns of the form
-/// ```
-///    tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
-  bool folded = false;
-  Location loc = op->getLoc();
-
-  Block *body = op.getBody();
-  OpBuilder b = OpBuilder::atBlockBegin(body);
-
-  // Update `input` and `output` operands and block arguments if necessary.
-  // Operands list: [lbs, ubs, steps, inputs, outputs].
-  // Block args list: [ivs, inputs, outputs].
-  for (size_t operandIndex = op.getNumControlOperands(),
-              bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
-       operandIndex < e; ++operandIndex, ++bbArgIndex) {
-    OpOperand &operand = op->getOpOperand(operandIndex);
-
-    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
-    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
-      operand.set(castOp.getOperand());
-      BlockArgument newBbArg = body->insertArgument(
-          bbArgIndex, castOp.getOperand().getType(), op.getLoc());
-      BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);
-
-      // Insert memref.cast back to the original type.
-      oldBbArg.replaceAllUsesWith(
-          b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
-      body->eraseArgument(oldBbArg.getArgNumber());
-
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -1247,630 +1209,9 @@ LogicalResult linalg::YieldOp::verify() {
   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
     return verifyYield(*this, cast<LinalgOp>(parentOp));
 
-  if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
-    // Check if output args with tensor types match results types.
-    SmallVector<Value, 2> tensorOuts;
-    llvm::copy_if(
-        tiledLoopOp.outputs(), std::back_inserter(tensorOuts),
-        [&](Value out) { return out.getType().isa<RankedTensorType>(); });
-    if (tensorOuts.size() != values().size())
-      return emitOpError("expected number of tensor output args = ")
-             << tensorOuts.size()
-             << " to match the number of yield operands = " << values().size();
-
-    TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts));
-    for (auto &item :
-         llvm::enumerate(llvm::zip(tensorTypes, getOperandTypes()))) {
-      Type outType, resultType;
-      unsigned index = item.index();
-      std::tie(outType, resultType) = item.value();
-      if (outType != resultType)
-        return emitOpError("expected yield operand ")
-               << index << " with type = " << resultType
-               << " to match output arg type = " << outType;
-    }
-    return success();
-  }
   return emitOpError("expected parent op with LinalgOp interface");
 }
 
-//===----------------------------------------------------------------------===//
-// TiledLoopOp
-//===----------------------------------------------------------------------===//
-
-void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
-                        ValueRange lowerBounds, ValueRange upperBounds,
-                        ValueRange steps, ValueRange inputs, ValueRange outputs,
-                        ArrayAttr iteratorTypes,
-                        function_ref<void(OpBuilder &, Location, ValueRange,
-                                          ValueRange, ValueRange)>
-                            bodyBuilderFn) {
-  build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
-        iteratorTypes, llvm::None, bodyBuilderFn);
-}
-
-void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
-                        ValueRange lowerBounds, ValueRange upperBounds,
-                        ValueRange steps, ValueRange inputs, ValueRange outputs,
-                        ArrayAttr iteratorTypes,
-                        Optional<ArrayAttr> distributionTypes,
-                        function_ref<void(OpBuilder &, Location, ValueRange,
-                                          ValueRange, ValueRange)>
-                            bodyBuilderFn) {
-  result.addOperands(lowerBounds);
-  result.addOperands(upperBounds);
-  result.addOperands(steps);
-  result.addOperands(inputs);
-  result.addOperands(outputs);
-  result.addAttribute(
-      TiledLoopOp::getOperandSegmentSizeAttr(),
-      builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
-                                static_cast<int32_t>(upperBounds.size()),
-                                static_cast<int32_t>(steps.size()),
-                                static_cast<int32_t>(inputs.size()),
-                                static_cast<int32_t>(outputs.size())}));
-  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
-
-  if (distributionTypes.hasValue())
-    result.addAttribute(getDistributionTypesAttrName(),
-                        distributionTypes.getValue());
-
-  // Add output types for `RankedTensorType` output arguments.
-  for (Value output : outputs) {
-    Type outputType = output.getType();
-    if (outputType.isa<RankedTensorType>())
-      result.addTypes(outputType);
-  }
-
-  OpBuilder::InsertionGuard guard(builder);
-  unsigned numIVs = steps.size();
-  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
-  SmallVector<Location, 8> argLocs(numIVs, result.location);
-  for (Value input : inputs) {
-    argTypes.push_back(input.getType());
-    argLocs.push_back(input.getLoc());
-  }
-  for (Value output : outputs) {
-    argTypes.push_back(output.getType());
-    argLocs.push_back(output.getLoc());
-  }
-  Region *bodyRegion = result.addRegion();
-  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
-
-  if (bodyBuilderFn) {
-    builder.setInsertionPointToStart(bodyBlock);
-    bodyBuilderFn(builder, result.location,
-                  bodyBlock->getArguments().take_front(numIVs),
-                  bodyBlock->getArguments().slice(numIVs, inputs.size()),
-                  bodyBlock->getArguments().take_back(outputs.size()));
-    TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
-  }
-}
-
-void TiledLoopOp::print(OpAsmPrinter &p) {
-  p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to ("
-    << upperBound() << ") step (" << step() << ")";
-
-  if (!inputs().empty()) {
-    p << " ins (";
-    llvm::interleaveComma(llvm::zip(getRegionInputArgs(), inputs()), p,
-                          [&](auto it) {
-                            p << std::get<0>(it) << " = " << std::get<1>(it)
-                              << ": " << std::get<1>(it).getType();
-                          });
-    p << ")";
-  }
-  if (!outputs().empty()) {
-    p << " outs (";
-    llvm::interleaveComma(llvm::zip(getRegionOutputArgs(), outputs()), p,
-                          [&](auto it) {
-                            p << std::get<0>(it) << " = " << std::get<1>(it)
-                              << ": " << std::get<1>(it).getType();
-                          });
-    p << ")";
-  }
-
-  if (llvm::any_of(iterator_types(), [](Attribute attr) {
-        return attr.cast<StringAttr>().getValue() !=
-               getParallelIteratorTypeName();
-      }))
-    p << " iterators" << iterator_types();
-
-  if (distribution_types().hasValue())
-    p << " distribution" << distribution_types().getValue();
-
-  p << ' ';
-  p.printRegion(region(), /*printEntryBlockArgs=*/false);
-  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
-                              TiledLoopOp::getOperandSegmentSizeAttr(),
-                              getIteratorTypesAttrName(),
-                              getDistributionTypesAttrName()});
-}
-
-ParseResult TiledLoopOp::parse(OpAsmParser &parser, OperationState &result) {
-  auto &builder = parser.getBuilder();
-  // Parse an opening `(` followed by induction variables followed by `)`
-  SmallVector<OpAsmParser::OperandType, 4> ivs;
-  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
-                                     OpAsmParser::Delimiter::Paren))
-    return failure();
-
-  // Parse loop bounds.
-  SmallVector<OpAsmParser::OperandType, 4> lower;
-  if (parser.parseEqual() ||
-      parser.parseOperandList(lower, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
-    return failure();
-
-  SmallVector<OpAsmParser::OperandType, 4> upper;
-  if (parser.parseKeyword("to") ||
-      parser.parseOperandList(upper, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
-    return failure();
-
-  // Parse step values.
-  SmallVector<OpAsmParser::OperandType, 4> steps;
-  if (parser.parseKeyword("step") ||
-      parser.parseOperandList(steps, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
-    return failure();
-
-  // Parse input tensors.
-  SmallVector<OpAsmParser::OperandType, 4> inputs, inputRegionArgs;
-  SmallVector<Type, 4> inputTypes;
-  if (succeeded(parser.parseOptionalKeyword("ins"))) {
-    SMLoc inputsOperandsLoc = parser.getCurrentLocation();
-
-    if (parser.parseAssignmentListWithTypes(inputRegionArgs, inputs,
-                                            inputTypes))
-      return failure();
-
-    if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
-                               result.operands))
-      return failure();
-  }
-
-  // Parse output tensors.
-  SmallVector<OpAsmParser::OperandType, 4> outputs, outputRegionArgs;
-  SmallVector<Type, 4> outputTypes;
-  if (succeeded(parser.parseOptionalKeyword("outs"))) {
-    SMLoc outputsOperandsLoc = parser.getCurrentLocation();
-
-    if (parser.parseAssignmentListWithTypes(outputRegionArgs, outputs,
-                                            outputTypes))
-      return failure();
-
-    if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
-                               result.operands))
-      return failure();
-    for (Type outputType : outputTypes)
-      if (outputType.isa<RankedTensorType>())
-        result.addTypes(outputType);
-  }
-
-  // Parse attributes.
-  SmallVector<Attribute, 4> iterTypes, distributionTypes;
-  auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
-    if (succeeded(parser.parseOptionalKeyword(keyword))) {
-      StringAttr attr;
-
-      if (parser.parseLSquare() || parser.parseAttribute(attr))
-        return failure();
-      attrs->push_back(attr);
-      for (int i = 1, e = ivs.size(); i < e; ++i) {
-        if (parser.parseComma() || parser.parseAttribute(attr))
-          return failure();
-        attrs->push_back(attr);
-      }
-      if (parser.parseRSquare())
-        return failure();
-    }
-    return success();
-  };
-  if (failed(parseAttr("iterators", &iterTypes)) ||
-      failed(parseAttr("distribution", &distributionTypes)))
-    return failure();
-
-  // Set all loop iterator types to "parallel" if they are not printed in IR.
-  if (iterTypes.empty()) {
-    auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
-    iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
-  }
-  result.addAttribute(getIteratorTypesAttrName(),
-                      builder.getArrayAttr(iterTypes));
-  if (!distributionTypes.empty())
-    result.addAttribute(getDistributionTypesAttrName(),
-                        builder.getArrayAttr(distributionTypes));
-  result.addAttribute(
-      TiledLoopOp::getOperandSegmentSizeAttr(),
-      builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
-                                static_cast<int32_t>(upper.size()),
-                                static_cast<int32_t>(steps.size()),
-                                static_cast<int32_t>(inputs.size()),
-                                static_cast<int32_t>(outputs.size())}));
-
-  // Parse the body.
-  Region *body = result.addRegion();
-
-  SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
-  regionTypes.append(inputTypes);
-  regionTypes.append(outputTypes);
-
-  SmallVector<OpAsmParser::OperandType, 4> regionArgs(ivs);
-  regionArgs.append(inputRegionArgs);
-  regionArgs.append(outputRegionArgs);
-
-  if (parser.parseRegion(*body, regionArgs, regionTypes))
-    return failure();
-
-  // Parse optional attributes.
-  parser.parseOptionalAttrDict(result.attributes);
-
-  return success();
-}
-
-Region &TiledLoopOp::getLoopBody() { return region(); }
-
-LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
-  for (auto *op : ops)
-    op->moveBefore(*this);
-  return success();
-}
-
-bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
-  return !region().isAncestor(value.getParentRegion());
-}
-
-LogicalResult TiledLoopOp::verify() {
-  // Check if iterator types are provided for every loop dimension.
-  if (iterator_types().size() != getNumLoops())
-    return emitOpError("expected iterator types array attribute size = ")
-           << iterator_types().size()
-           << " to match the number of loops = " << getNumLoops();
-
-  // Check if types of input arguments match region args types.
-  for (auto &item :
-       llvm::enumerate(llvm::zip(inputs(), getRegionInputArgs()))) {
-    Value input, inputRegionArg;
-    unsigned index = item.index();
-    std::tie(input, inputRegionArg) = item.value();
-    if (input.getType() != inputRegionArg.getType())
-      return emitOpError("expected input arg ")
-             << index << " with type = " << input.getType()
-             << " to match region arg " << index + getNumLoops()
-             << " type = " << inputRegionArg.getType();
-  }
-
-  // Check if types of input arguments match region args types.
-  for (auto &item :
-       llvm::enumerate(llvm::zip(outputs(), getRegionOutputArgs()))) {
-    Value output, outputRegionArg;
-    unsigned index = item.index();
-    std::tie(output, outputRegionArg) = item.value();
-    if (output.getType() != outputRegionArg.getType())
-      return emitOpError("expected output arg ")
-             << index << " with type = " << output.getType()
-             << " to match region arg "
-             << index + getNumLoops() + inputs().size()
-             << " type = " << outputRegionArg.getType();
-  }
-  return success();
-}
-
-namespace {
-
-static constexpr int64_t kNoMatch = -1;
-
-// Folds away TiledLoopOp inputs if they have no uses within the body.
-//
-// Example:
-//
-// %0 = linalg.tiled_loop ...  ins (%in_ = %in: tensor<...>,
-//                                  %in_buf_ = %in_buf: memref<...>) {...}
-// Becomes
-//
-// linalg.tiled_loop ...  ins (%in_buf_ = %in_buf: memref<...>) {...}
-struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
-  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
-                                PatternRewriter &rewriter) const final {
-    SmallVector<Value, 2> newInputs, regionInputTensorArgs;
-    // Store ids of the corresponding old and new input operands.
-    SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(),
-                                            kNoMatch);
-    for (const auto &en : llvm::enumerate(
-             llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) {
-      Value in, bbArg;
-      size_t index = en.index();
-      std::tie(in, bbArg) = en.value();
-      if (!bbArg.use_empty()) {
-        oldInputIdToNew[index] = newInputs.size();
-        newInputs.push_back(in);
-      }
-    }
-    if (newInputs.size() == tiledLoop.inputs().size())
-      return failure();
-    Location loc = tiledLoop.getLoc();
-    auto newTiledLoop = rewriter.create<TiledLoopOp>(
-        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
-        newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(),
-        tiledLoop.distribution_types());
-
-    // Clone the region.
-    BlockAndValueMapping bvm;
-    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
-    bvm.map(tiledLoop.getRegionOutputArgs(),
-            newTiledLoop.getRegionOutputArgs());
-    for (const auto &en : llvm::enumerate(oldInputIdToNew))
-      if (en.value() != kNoMatch)
-        bvm.map(tiledLoop.getRegionInputArgs()[en.index()],
-                newTiledLoop.getRegionInputArgs()[en.value()]);
-    OpBuilder innerBuilder =
-        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
-    for (auto &op : *tiledLoop.getBody())
-      innerBuilder.clone(op, bvm);
-    rewriter.replaceOp(tiledLoop, newTiledLoop.getResults());
-
-    return success();
-  }
-};
-
-} // namespace
-
-/// A simple, conservative analysis to determine if the loop is shape
-/// conserving. I.e., the type of the arg-th yielded value is the same as the
-/// type of the corresponding basic block argument of the loop.
-/// Note: This function handles only simple cases. Expand as needed.
-static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
-  auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
-  if (yieldOp.values().empty())
-    // Tiled loop either has no outputs or is a "memref-based version". In
-    // either case, the loop is shape conserving.
-    return true;
-  assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
-         "arg is out of bounds");
-  Value value = yieldOp.values()[arg];
-  while (value) {
-    if (value == loopOp.getRegionOutputArgs()[arg])
-      return true;
-    OpResult opResult = value.dyn_cast<OpResult>();
-    if (!opResult)
-      return false;
-
-    using tensor::InsertSliceOp;
-    value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
-                .template Case<InsertSliceOp>(
-                    [&](InsertSliceOp op) { return op.dest(); })
-                .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
-                  return isShapePreserving(loopOp, opResult.getResultNumber())
-                             ? loopOp.outputs()[opResult.getResultNumber()]
-                             : Value();
-                })
-                .Default([&](auto op) { return Value(); });
-  }
-  return false;
-}
-
-namespace {
-
-/// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
-/// to dim(y) where `y` is the initial input/output value of the argument.
-///
-/// E.g.:
-/// %y = ... : tensor<...>
-/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
-///   tensor.dim %x, %c0 : tensor<...>
-/// }
-///
-/// is folded to:
-/// %y = ... : tensor<...>
-/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
-///   tensor.dim %y, %c0 : tensor<...>
-/// }
-///
-/// Note: Dim ops are folded only if it can be proven that the runtime type of
-/// the yielded value (in case of outputs) does not change with loop iterations.
-template <typename OpTy>
-struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy dimOp,
-                                PatternRewriter &rewriter) const final {
-    auto src = dimOp.source().template dyn_cast<BlockArgument>();
-    if (!src)
-      return failure();
-    auto loopOp =
-        dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
-    if (!loopOp)
-      return failure();
-    unsigned numLoops = loopOp.getNumLoops();
-    unsigned numInputArgs = loopOp.getRegionInputArgs().size();
-    if (src.getArgNumber() >= numInputArgs + numLoops &&
-        !isShapePreserving(loopOp,
-                           src.getArgNumber() - numInputArgs - numLoops))
-      return failure();
-
-    auto inputArgs = loopOp.getRegionInputArgs();
-    auto it1 = llvm::find(inputArgs, src);
-    if (it1 != inputArgs.end()) {
-      rewriter.updateRootInPlace(dimOp, [&] {
-        dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]);
-      });
-      return success();
-    }
-
-    auto outputArgs = loopOp.getRegionOutputArgs();
-    auto it2 = llvm::find(outputArgs, src);
-    if (it2 != outputArgs.end()) {
-      rewriter.updateRootInPlace(dimOp, [&] {
-        dimOp.sourceMutable().assign(
-            loopOp.outputs()[it2 - outputArgs.begin()]);
-      });
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-/// Fold dim(r) where `r` is the result of a TiledLoopOp to dim(y) where `y`
-/// is the initial output value of the loop.
-///
-/// E.g.:
-/// %y = ... : tensor<...>
-/// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
-///   ...
-/// }
-/// %0 = tensor.dim %r, %c0 : tensor<...>
-///
-/// is folded to:
-/// %y = ... : tensor<...>
-/// linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
-///   ...
-/// }
-/// %0 = tensor.dim %y, %c0 : tensor<...>
-///
-/// Note: Dim ops are folded only if it can be proven that the runtime type of
-/// the yielded value (in case of outputs) does not change with loop iterations.
-template <typename OpTy>
-struct DimOfTiledLoopResultFolder : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy dimOp,
-                                PatternRewriter &rewriter) const final {
-    auto loopOp = dimOp.source().template getDefiningOp<TiledLoopOp>();
-    if (!loopOp)
-      return failure();
-    auto opResult = dimOp.source().template cast<OpResult>();
-    unsigned resultNumber = opResult.getResultNumber();
-    if (!isShapePreserving(loopOp, resultNumber))
-      return failure();
-    rewriter.updateRootInPlace(dimOp, [&]() {
-      dimOp.sourceMutable().assign(loopOp.outputs()[resultNumber]);
-    });
-    return success();
-  }
-};
-
-// Folds away TiledLoopOp output tensors when the following conditions are met:
-// * result of `linalg.tiled_loop` has no uses
-// * output tensor is the argument of `linalg.yield`
-//
-// Example:
-//
-// %0 = linalg.tiled_loop ...  outs (%o_ = %out: tensor<...>,
-//                                   %obuf_ = %out_buf: memref<...>) {
-//   ...
-//   linalg.yield %o_ : tensor ...
-// }
-//
-// Becomes
-//
-// linalg.tiled_loop ...  outs (%obuf_ = %out_buf: memref<...>) {
-//   ...
-//   linalg.yield
-// }
-struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
-  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
-                                PatternRewriter &rewriter) const final {
-    if (tiledLoop.getNumResults() == 0)
-      return failure();
-
-    Block *block = tiledLoop.getBody();
-    auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
-
-    // Match the pattern and collect output buffers that will replace the output
-    // tensors and also the ops that will be ignored when cloning the body.
-    SmallVector<Value, 2> newOutputOperands, newYieldArgs;
-    int resultId = 0;
-    // Store ids of the corresponding old and new output operands.
-    SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
-                                             kNoMatch);
-    // Store ids of the corresponding old and new results.
-    SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(),
-                                             kNoMatch);
-    SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults());
-    for (const auto &en : llvm::enumerate(
-             llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
-      size_t index = en.index();
-      Value out = std::get<0>(en.value());
-      Value outRegionArg = std::get<1>(en.value());
-
-      if (!out.getType().isa<RankedTensorType>()) {
-        oldOutputIdToNew[index] = newOutputOperands.size();
-        newOutputOperands.push_back(out);
-        continue;
-      }
-      Value result = tiledLoop.getResult(resultId);
-      Value yieldArg = yieldOp.getOperand(resultId);
-      if (yieldArg != outRegionArg || !result.use_empty()) {
-        oldOutputIdToNew[index] = newOutputOperands.size();
-        oldResultIdToNew[resultId] = newYieldArgs.size();
-        resultReplacement[resultId] = out;
-        newOutputOperands.push_back(out);
-        newYieldArgs.push_back(yieldArg);
-      }
-      ++resultId;
-    }
-    if (newOutputOperands.size() == tiledLoop.outputs().size())
-      return failure();
-
-    Location loc = tiledLoop.getLoc();
-    auto newTiledLoop = rewriter.create<TiledLoopOp>(
-        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
-        tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(),
-        tiledLoop.distribution_types());
-
-    // Clone the region.
-    BlockAndValueMapping bvm;
-    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
-    bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
-    for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
-      if (en.value() != kNoMatch)
-        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
-                newTiledLoop.getRegionOutputArgs()[en.value()]);
-      else
-        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
-                tiledLoop.outputs()[en.index()]);
-    }
-    OpBuilder innerBuilder =
-        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
-    for (auto &op : tiledLoop.getBody()->without_terminator())
-      innerBuilder.clone(op, bvm);
-    innerBuilder.create<linalg::YieldOp>(
-        loc, llvm::to_vector<2>(llvm::map_range(
-                 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
-
-    for (const auto &en : llvm::enumerate(oldResultIdToNew))
-      if (en.value() != kNoMatch)
-        resultReplacement[en.index()] = newTiledLoop.getResult(en.value());
-    rewriter.replaceOp(tiledLoop, resultReplacement);
-
-    return success();
-  }
-};
-} // namespace
-
-void TiledLoopOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<TiledLoopInputsFolder, TiledLoopResultsFolder,
-              DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
-              DimOfTiledLoopInsOutsFolder<memref::DimOp>,
-              DimOfTiledLoopResultFolder<tensor::DimOp>,
-              DimOfTiledLoopResultFolder<memref::DimOp>>(context);
-}
-
-LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCastInTiledLoopOp(*this);
-}
-
 //===----------------------------------------------------------------------===//
 // IndexOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 799c13726091c..b9d1fc2c29f71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -246,203 +246,6 @@ struct InitTensorOpInterface
   }
 };
 
-/// Bufferization of linalg.tiled_loop. Replace with a new linalg.tiled_loop
-/// that operates entirely on memrefs.
-struct TiledLoopOpInterface
-    : public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
-                                                    linalg::TiledLoopOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const BufferizationState &state) const {
-    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
-
-    // linalg.tiled_loop operands alone do not bufferize to a memory read, but
-    // one of the uses of their matching bbArgs may.
-    return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    auto bufferizableOp = cast<BufferizableOpInterface>(op);
-
-    // Only operands with an aliasing OpResult (i.e., output operands) bufferize
-    // to a memory write.
-    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
-  }
-
-  SmallVector<OpResult>
-  getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                      const BufferizationState &state) const {
-    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
-
-    // Output operands are tied to their corresponding OpResults.
-    OpResult opResult = tiledLoopOp.getTiedOpResult(opOperand);
-    if (!opResult)
-      return {};
-    return {opResult};
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationState &state) const {
-    return BufferRelation::Equivalent;
-  }
-
-  bool isWritable(Operation *op, Value value,
-                  const BufferizationState &state) const {
-    // Interestingly, linalg::TiledLoopOp's bbArgs can **always** be viewed
-    // inplace from the perspective of nested ops:
-    //   1. Either the matching iter operand is not bufferized inplace and an
-    //      alloc + optional copy makes the bbArg itself inplaceable.
-    //   2. Or the matching iter operand is bufferized inplace and bbArg just
-    //      bufferizes to that too.
-    return true;
-  }
-
-  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationState &state) const {
-    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
-
-    // Compute new inputs, outputs and results.
-    SmallVector<Value> newInputs, newOutputs, newResults;
-    for (unsigned i = tiledLoopOp.getNumControlOperands();
-         i < tiledLoopOp->getNumOperands(); ++i) {
-      OpOperand &operand = tiledLoopOp->getOpOperand(i);
-      Value rewrittenValue = operand.get();
-      if (rewrittenValue.getType().isa<TensorType>()) {
-        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, operand);
-        if (failed(bufferOrFailure))
-          return failure();
-        rewrittenValue = *bufferOrFailure;
-      }
-      if (i <
-          tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) {
-        newInputs.push_back(rewrittenValue);
-      } else {
-        newOutputs.push_back(rewrittenValue);
-        if (operand.get().getType().isa<TensorType>())
-          newResults.push_back(rewrittenValue);
-      }
-    }
-
-    // Create new TiledLoopOp.
-    auto newTiledLoopOp = rewriter.create<TiledLoopOp>(
-        tiledLoopOp.getLoc(), tiledLoopOp.lowerBound(),
-        tiledLoopOp.upperBound(), tiledLoopOp.step(), newInputs, newOutputs,
-        tiledLoopOp.iterator_types(), tiledLoopOp.distribution_types());
-
-    // Remove terminator.
-    if (!newTiledLoopOp.getBody()->empty())
-      rewriter.eraseOp(tiledLoopOp.getBody()->getTerminator());
-
-    // Compute new loop body arguments.
-    SmallVector<Value> newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs;
-    ValueRange newInductionVars = newTiledLoopOp.getInductionVars();
-    newBlockArgs.append(newInductionVars.begin(), newInductionVars.end());
-
-    ValueRange newRegionInArgs = newTiledLoopOp.getRegionInputArgs();
-    ValueRange newRegionOutArgs = newTiledLoopOp.getRegionOutputArgs();
-    newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end());
-    newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end());
-
-    ValueRange oldRegionInArgs = tiledLoopOp.getRegionInputArgs();
-    ValueRange oldRegionOutArgs = tiledLoopOp.getRegionOutputArgs();
-    oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end());
-    oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end());
-    assert(newRegionInArgs.size() == oldRegionInArgs.size() &&
-           "expected same number of input args");
-    assert(newRegionOutArgs.size() == oldRegionOutArgs.size() &&
-           "expected same number of output args");
-
-    for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) {
-      Value oldArg = std::get<0>(it);
-      Value newArg = std::get<1>(it);
-      rewriter.setInsertionPointToStart(newTiledLoopOp.getBody());
-      if (oldArg.getType().isa<TensorType>()) {
-        newBlockArgs.push_back(rewriter.create<bufferization::ToTensorOp>(
-            oldArg.getLoc(), newArg));
-      } else {
-        newBlockArgs.push_back(newArg);
-      }
-    }
-
-    // Move old body into new loop.
-    rewriter.mergeBlocks(tiledLoopOp.getBody(), newTiledLoopOp.getBody(),
-                         newBlockArgs);
-
-    // Replace previous terminator with a new one that does not yield anything.
-    auto oldTerminator =
-        cast<linalg::YieldOp>(newTiledLoopOp.getBody()->getTerminator());
-    rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody());
-    auto newTerminator =
-        rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());
-
-    // Copy buffer of yielded tensor to output buffer. If everything bufferized
-    // inplace, this copy will fold away.
-    rewriter.setInsertionPoint(newTerminator);
-    for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) {
-      Value output = std::get<1>(it);
-      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          newTerminator.getLoc(), output.getType(), std::get<0>(it));
-      if (failed(createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp,
-                              output, state.getOptions())))
-        return failure();
-    }
-
-    // Erase old terminator.
-    rewriter.eraseOp(oldTerminator);
-
-    // Replace results and delete old op.
-    replaceOpWithBufferizedValues(rewriter, op, newResults);
-
-    return success();
-  }
-};
-
-/// Bufferization of linalg.yield. Bufferized as part of linalg.tiled_loop's
-/// bufferization.
-struct YieldOpInterface
-    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
-                                                    linalg::YieldOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const BufferizationState &state) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return false;
-  }
-
-  SmallVector<OpResult>
-  getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                      const BufferizationState &state) const {
-    return {};
-  }
-
-  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
-                            const BufferizationState &state) const {
-    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
-    // may be generated inside the block. We should not return/yield allocations
-    // when possible.
-    return true;
-  }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationState &state) const {
-    auto yieldOp = cast<linalg::YieldOp>(op);
-
-    if (!yieldOp->getParentOfType<TiledLoopOp>())
-      return yieldOp->emitError(
-          "expected that linalg.yield terminates a tiled_loop");
-
-    assert(yieldOp->getOpOperands().empty() &&
-           "expected that linalg.yield was bufferized together with"
-           " tiled_loop");
-    return success();
-  }
-};
-
 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers
 /// the `BufferizableOpInterface` with each of them.
 template <typename... OpTys>
@@ -701,8 +504,6 @@ LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addOpInterface<linalg::InitTensorOp, InitTensorOpInterface>();
-  registry.addOpInterface<linalg::TiledLoopOp, TiledLoopOpInterface>();
-  registry.addOpInterface<linalg::YieldOp, YieldOpInterface>();
 
   // Register all Linalg structured ops. `LinalgOp` is an interface and it is
   // not possible to attach an external interface to an existing interface.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ec8c8c438635e..a688eb59a6f12 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,7 +4,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   CodegenStrategy.cpp
   ComprehensiveBufferizePass.cpp
   Detensorize.cpp
-  Distribution.cpp
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
   ElementwiseToLinalg.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
deleted file mode 100644
index 692df291b2f66..0000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
+++ /dev/null
@@ -1,87 +0,0 @@
-//===- Distibution.cpp - linalg named ops to generic ops  --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the Linalg distibution pass. It updates `tiled_loop`
-// control variables depending on the distribution type.
-//
-//===----------------------------------------------------------------------===//
-//
-#include <utility>
-
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-#define DEBUG_TYPE "linalg-distribution"
-
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-namespace {
-
-struct DistributeTiledLoopPattern
-    : public OpRewritePattern<linalg::TiledLoopOp> {
-  DistributeTiledLoopPattern(MLIRContext *context,
-                             LinalgLoopDistributionOptions options,
-                             LinalgTransformationFilter marker)
-      : OpRewritePattern<linalg::TiledLoopOp>(context),
-        options(std::move(options)), marker(std::move(marker)) {}
-  LogicalResult matchAndRewrite(linalg::TiledLoopOp op,
-                                PatternRewriter &rewriter) const override {
-    if (failed(marker.checkAndNotify(rewriter, op)))
-      return failure();
-    if (!op.distribution_types().hasValue())
-      return failure();
-
-    Location loc = op.getLoc();
-    SmallVector<Value, 2> newLowerBounds = op.lowerBound();
-    SmallVector<Value, 2> newUpperBounds = op.upperBound();
-    SmallVector<Value, 2> newSteps = op.step();
-
-    // Update bounds and steps.
-    auto distributionTypes = op.distribution_types().getValue();
-    for (int i = 0, e = op.getNumLoops(); i < e; ++i) {
-      StringRef type = distributionTypes[i].cast<StringAttr>().getValue();
-      auto procInfoCallback = options.procInfoMap.find(type);
-      if (procInfoCallback == options.procInfoMap.end())
-        continue;
-
-      if (!isParallelIterator(op.iterator_types()[i])) {
-        op.emitOpError("only support for parallel loops is implemented");
-        return failure();
-      }
-      ProcInfo info = procInfoCallback->second(rewriter, loc);
-      updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs,
-                                        newLowerBounds[i], newUpperBounds[i],
-                                        newSteps[i]);
-    }
-    rewriter.updateRootInPlace(op, [&] {
-      op.setLowerBounds(newLowerBounds);
-      op.setUpperBounds(newUpperBounds);
-      op.setSteps(newSteps);
-    });
-    marker.replaceLinalgTransformationFilter(rewriter, op);
-    return success();
-  }
-
-private:
-  LinalgLoopDistributionOptions options;
-  LinalgTransformationFilter marker;
-};
-
-} // namespace
-
-void mlir::linalg::populateLinalgDistributeTiledLoopPattern(
-    RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
-    const LinalgTransformationFilter &marker) {
-  patterns.add<DistributeTiledLoopPattern>(patterns.getContext(), opts, marker);
-}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 32d8ee098bcec..94edb8b630876 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -104,63 +104,8 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }
 
-// Return tiled operands for the fused producer op. When fusing into
-// `linalg.tiled_loop` one has to update `input` and `output` arguments of the
-// loop correspondingly.
-// Each input tensor of the producer op has to be added to `inputs` of the
-// `tiled_loop` if it is not present there already. Each output tensor has to
-// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
-// on whether the correponding result is an input or an output to the loop.
-//
-// NOTE: This way of updating the arguments of the `tiled_loop` assumes that the
-// intermediate result is not used by any other operation but the consumer. A
-// more generic way is to append all missing output tensors of the producer to
-// the tiled loop outputs and hence modify the number of the results, since we
-// would need to add the intermediate results to `linalg.yield`. After that a
-// canonicalization pass would move the unused output args of the `tiled_loop`
-// to the `input` section.
-static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
-  auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
-  if (!tiledLoop)
-    return producer.getInputAndOutputOperands();
-
-  SmallVector<Value> tiledOperands;
-  assert(producer.hasTensorSemantics() &&
-         "only fusion on tensors is currently supported for TiledLinalgOp");
-
-  for (OpOperand *producerInput : producer.getInputOperands()) {
-    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
-    if (addedInput == nullptr)
-      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
-    BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
-    tiledOperands.push_back(addedBlockArg);
-  }
-  for (OpOperand *producerOutput : producer.getOutputOperands()) {
-    OpResult result = producer.getTiedOpResult(producerOutput);
-    OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
-    OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
-    assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
-           "The result should be present in `input` or `output` args of "
-           "`tiled_loop");
-
-    bool isInput = resultInputOperand;
-    int opNumber = isInput ? resultInputOperand->getOperandNumber()
-                           : resultOutputOperand->getOperandNumber();
-
-    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
-    if (addedOutput == nullptr)
-      addedOutput =
-          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
-                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());
-
-    OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
-    auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
-    auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
-    resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
-    tiledLoop.eraseOperand(b, resultOperand);
-    tiledOperands.push_back(addedBlockArg);
-  }
-  return tiledOperands;
+static SmallVector<Value> getTiledOperands(LinalgOp producer) {
+  return producer.getInputAndOutputOperands();
 }
 
 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
@@ -198,7 +143,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
 
   // Compute subranges for all tensor input/output operands.
   clonedShapes.append(makeTiledShapes(b, loc, producer,
-                                      getTiledOperands(b, producer), ivs,
+                                      getTiledOperands(producer), ivs,
                                       tileSizes, sizeBounds));
 
   // Iterate over the results in order.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index c9e3c6c955703..5a5554992341f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -260,72 +260,6 @@ class LinalgRewritePattern : public RewritePattern {
   }
 };
 
-/// Converts tiled_loop to SCF loop nests. All parallel dimensions are collected
-/// into an scf.parallel loop and all sequential dimensions will result in the
-/// nested scf.for loop nest. The pattern assumes that a tiled loop with
-/// iterator_types ["reduction", "parallel", "reduction"] can be reordered. It
-/// is true for the tiling that is currently suppported by Linalg.
-struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
-  using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
-                                PatternRewriter &rewriter) const override {
-    // Fail conversion if the `tiled_loop` has not been bufferized.
-    if (!tiledLoop.hasBufferSemantics())
-      return failure();
-
-    // Collect loop control parameters for parallel and sequential dimensions.
-    SmallVector<Value, 3> seqLBs, seqUBs, seqSteps, seqIVs;
-    SmallVector<Value, 3> parLBs, parUBs, parSteps, parIVs;
-    for (const auto &en : llvm::enumerate(
-             llvm::zip(tiledLoop.lowerBound(), tiledLoop.upperBound(),
-                       tiledLoop.step(), tiledLoop.getInductionVars()))) {
-      Value lb, ub, step, iv;
-      std::tie(lb, ub, step, iv) = en.value();
-      if (tiledLoop.isParallelDimension(en.index())) {
-        parLBs.push_back(lb);
-        parUBs.push_back(ub);
-        parSteps.push_back(step);
-        parIVs.push_back(iv);
-      } else {
-        seqLBs.push_back(lb);
-        seqUBs.push_back(ub);
-        seqSteps.push_back(step);
-        seqIVs.push_back(iv);
-      }
-    }
-
-    Location loc = tiledLoop.getLoc();
-    auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc,
-                                               ValueRange ivs) {
-      BlockAndValueMapping bvm;
-      bvm.map(parIVs, ivs);
-      bvm.map(tiledLoop.getRegionInputArgs(), tiledLoop.inputs());
-      bvm.map(tiledLoop.getRegionOutputArgs(), tiledLoop.outputs());
-
-      // If not all dimensions of the tiled loop are parallel, an scf.for loop
-      // nest is generated.
-      if (!seqIVs.empty()) {
-        scf::LoopNest nest =
-            scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps,
-                               [&](OpBuilder &builder, Location loc,
-                                   ValueRange ivs) { bvm.map(seqIVs, ivs); });
-        builder.setInsertionPointToStart(nest.loops.back().getBody());
-      }
-      for (auto &op : tiledLoop.getBody()->without_terminator())
-        builder.clone(op, bvm);
-    };
-
-    if (parIVs.empty())
-      generateForLoopNestAndCloneBody(rewriter, loc, llvm::None);
-    else
-      rewriter.create<scf::ParallelOp>(loc, parLBs, parUBs, parSteps,
-                                       generateForLoopNestAndCloneBody);
-    rewriter.eraseOp(tiledLoop);
-    return success();
-  }
-};
-
 /// Local folding pattern for AffineApplyOp that we can apply greedily.
 /// This replaces AffineApplyOp by the proper value in cases where the
 /// associated map is trivial.
@@ -402,136 +336,8 @@ struct LowerToParallelLoops
   }
 };
 
-struct LowerTiledLoopsToSCF
-    : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    RewritePatternSet patterns(context);
-    populateTiledLoopToSCFPattern(patterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-  }
-};
 } // namespace
 
-/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly
-/// into two TiledLoopOps: One where the step divides the iteration space
-/// evenly, followed another one for the last (partial) iteration (if any). This
-/// function only rewrites the `idx`-th loop of the loop nest represented by
-/// the TiledLoopOp. To peel the entire loop nest, this function must be called
-/// multiple times.
-///
-/// This function rewrites the given TiledLoopOp in-place and creates a new
-/// TiledLoopOp for the last iteration. It replaces all uses of the original
-/// TiledLoopOp with the results of the newly generated one.
-///
-/// The newly generated TiledLoopOp is returned via `result`. The boundary
-/// at which the loop is split (new upper bound) is returned via `splitBound`.
-/// The return value indicates whether the TiledLoopOp was rewritten or not.
-static LogicalResult peelTiledLoop(RewriterBase &b, TiledLoopOp loopOp,
-                                   int64_t idx, TiledLoopOp &result,
-                                   Value &splitBound) {
-  Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx],
-        step = loopOp.step()[idx];
-  auto ubInt = getConstantIntValue(ub);
-
-  auto loc = loopOp.getLoc();
-  AffineExpr exprLb, exprUb, exprStep;
-  bindSymbols(b.getContext(), exprLb, exprUb, exprStep);
-  // New upper bound: %ub - (%ub - %lb) mod %step
-  auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)});
-  SmallVector<Value> operands{lb, ub, step};
-  mlir::canonicalizeMapAndOperands(&modMap, &operands);
-  modMap = mlir::simplifyAffineMap(modMap);
-  RewriterBase::InsertionGuard guard(b);
-  b.setInsertionPoint(loopOp);
-  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap, operands);
-  // No specialization necessary if step already divides upper bound evenly.
-  if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound)))
-    return failure();
-
-  // Create remainder loop.
-  b.setInsertionPointAfter(loopOp);
-  auto remainderLoop = cast<TiledLoopOp>(b.clone(*loopOp.getOperation()));
-  loopOp.replaceAllUsesWith(remainderLoop->getResults());
-  // Outputs: Take tensors from main loop's results. Take memrefs from main
-  // loop's outputs.
-  SmallVector<Value> remainderOutputs;
-  for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) {
-    remainderOutputs.push_back(loopOp.outputs()[o].getType().isa<MemRefType>()
-                                   ? loopOp.outputs()[o]
-                                   : loopOp->getResult(t++));
-  }
-  remainderLoop.outputsMutable().assign(remainderOutputs);
-
-  // Set new loop bounds.
-  b.updateRootInPlace(loopOp, [&]() {
-    SmallVector<Value> ubs = loopOp.upperBound();
-    ubs[idx] = splitBound;
-    loopOp.upperBoundMutable().assign(ubs);
-  });
-  SmallVector<Value> lbs = remainderLoop.lowerBound();
-  lbs[idx] = splitBound;
-  remainderLoop.lowerBoundMutable().assign(lbs);
-
-  result = remainderLoop;
-  return success();
-}
-
-template <typename OpTy, bool IsMin>
-static void
-rewriteAffineOpAfterPeeling(RewriterBase &rewriter, TiledLoopOp mainLoop,
-                            TiledLoopOp remainderLoop, Value mainIv,
-                            Value remainderIv, Value ub, Value step) {
-  mainLoop.walk([&](OpTy affineOp) {
-    AffineMap map = affineOp.getAffineMap();
-    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
-                                     affineOp.operands(), IsMin, mainIv, ub,
-                                     step, /*insideLoop=*/true);
-  });
-  remainderLoop.walk([&](OpTy affineOp) {
-    AffineMap map = affineOp.getAffineMap();
-    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
-                                     affineOp.operands(), IsMin, remainderIv,
-                                     ub, step, /*insideLoop=*/false);
-  });
-}
-
-LogicalResult mlir::linalg::peelAndCanonicalizeTiledLoop(RewriterBase &rewriter,
-                                                         TiledLoopOp loopOp,
-                                                         int64_t idx,
-                                                         TiledLoopOp &result) {
-  int64_t numLoops = loopOp.iterator_types().size();
-  if (idx < 0 || numLoops <= idx)
-    return failure();
-
-  Value ub = loopOp.upperBound()[idx];
-  TiledLoopOp remainderLoop;
-  Value splitBound;
-  if (failed(peelTiledLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
-    return failure();
-
-  // Rewrite affine.min and affine.max ops.
-  Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.step()[idx],
-        remainderIv = remainderLoop.getInductionVars()[idx];
-
-  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
-      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
-  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
-      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
-
-  result = remainderLoop;
-  return success();
-}
-
-void mlir::linalg::populateTiledLoopToSCFPattern(RewritePatternSet &patterns) {
-  patterns.add<TiledLoopToSCFPattern>(patterns.getContext());
-}
-
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createConvertLinalgTiledLoopsToSCFPass() {
-  return std::make_unique<LowerTiledLoopsToSCF>();
-}
-
 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
   return std::make_unique<LowerToLoops>();
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0271857383746..2e1418c529a25 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -271,8 +271,6 @@ mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
   case LinalgTilingLoopType::ParallelLoops:
     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
-  case LinalgTilingLoopType::TiledLoops:
-    return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
   default:;
   }
   return failure();
@@ -453,13 +451,10 @@ static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
 namespace {
 struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
   LinalgTilingPass() = default;
-  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType,
-                   ArrayRef<StringRef> distributionTypes) {
+  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType) {
     this->tileSizes = tileSizes;
     this->loopType = "";
     this->loopTypeEnum = loopType;
-    this->distributionTypes = llvm::to_vector<2>(llvm::map_range(
-        distributionTypes, [](StringRef ref) { return ref.str(); }));
   }
 
   void runOnOperation() override {
@@ -469,14 +464,9 @@ struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
             .Case("for", LinalgTilingLoopType::Loops)
             .Case("affine", LinalgTilingLoopType::AffineLoops)
             .Case("parallel", LinalgTilingLoopType::ParallelLoops)
-            .Case("tiled_loop", LinalgTilingLoopType::TiledLoops)
             .Default(loopTypeEnum);
-    auto distTypes = llvm::to_vector<2>(llvm::map_range(
-        distributionTypes, [](std::string &str) { return StringRef(str); }));
-    auto options = LinalgTilingOptions()
-                       .setTileSizes(tileSizes)
-                       .setLoopType(type)
-                       .setDistributionTypes(distTypes);
+    auto options =
+        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(type);
     MLIRContext *ctx = funcOp.getContext();
     RewritePatternSet patterns(ctx);
     insertTilingPatterns(patterns, options);
@@ -501,8 +491,6 @@ struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
 
 std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
-                             linalg::LinalgTilingLoopType loopType,
-                             ArrayRef<StringRef> distributionTypes) {
-  return std::make_unique<LinalgTilingPass>(tileSizes, loopType,
-                                            distributionTypes);
+                             linalg::LinalgTilingLoopType loopType) {
+  return std::make_unique<LinalgTilingPass>(tileSizes, loopType);
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9f177f0a1b92b..f6a5304e1cff9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -299,18 +299,6 @@ static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
       .Default([&](Operation *op) { return op->getResults(); });
 }
 
-/// Try to peel a TiledLoopOp and return the new result.
-static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
-                                      TiledLoopOp tiledLoop, int64_t idx) {
-  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
-         "requested peeling of non-existing loop");
-  TiledLoopOp result;
-  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
-    return result->getResults();
-  assert(!result && "expected that loop was not peeled");
-  return tiledLoop->getResults();
-}
-
 /// Peel loops after tiling.
 void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                      ArrayRef<int64_t> peeledLoops,
@@ -320,17 +308,7 @@ void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
            "requested peeling of non-existing loop");
     SmallVector<Value, 4> loopResults;
     Operation *loopOp = res.loops[loop];
-    if (loopType == LinalgTilingLoopType::TiledLoops) {
-      assert(llvm::all_of(
-                 res.loops,
-                 [&](Operation *op) { return op == res.loops.front(); }) &&
-             "expected that all loop ops are the same TiledLoopOp");
-      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
-      assert(tiledLoopOp && "expected TiledLoopOp");
-      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
-    } else {
-      loopResults = peelLoop(rewriter, loopOp);
-    }
+    loopResults = peelLoop(rewriter, loopOp);
 
     // The result of the loop nest may change with peeling.
     if (res.tensorResults.size() == loopOp->getNumResults() &&

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 98a62a0f3cd6f..3dfb336a9bf17 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -125,7 +125,6 @@ RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
 template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
 template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
 template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
-template struct mlir::linalg::GenerateLoopNest<TiledLoopOp>;
 
 /// Given a list of subview ranges, extract individual values for lower, upper
 /// bounds and steps and put them into the corresponding vectors.
@@ -537,39 +536,6 @@ void GenerateLoopNest<AffineForOp>::doit(
                             });
 }
 
-/// Specialization to build an linalg.tiled_loop
-template <>
-void GenerateLoopNest<TiledLoopOp>::doit(
-    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<Attribute> iteratorTypes,
-    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
-                                  ValueRange)>
-        bodyBuilderFn,
-    Optional<LinalgLoopDistributionOptions> distributionOptions,
-    ArrayRef<StringRef> distributionTypes) {
-  SmallVector<ProcInfo, 2> procInfo;
-  SmallVector<Value, 4> lbs, ubs, steps;
-  unpackRanges(loopRanges, lbs, ubs, steps);
-
-  auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                              ValueRange ivs, ValueRange inputs,
-                              ValueRange outputs) {
-    SmallVector<Value> operandValuesToUse = inputs;
-    operandValuesToUse.append(outputs.begin(), outputs.end());
-    scf::ValueVector results =
-        bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse);
-    nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
-  };
-
-  SmallVector<Value> inputOperands = linalgOp.getInputOperands();
-  SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
-  auto tiledLoop =
-      b.create<TiledLoopOp>(loc, lbs, ubs, steps, inputOperands, outputOperands,
-                            b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
-  if (!distributionTypes.empty())
-    tiledLoop.setDistributionTypes(b, distributionTypes);
-}
-
 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
 void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                        Value nprocs, Value &lb, Value &ub,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index c3405887431ff..e3f213f8cd6ef 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -18,31 +18,6 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
 
 // -----
 
-#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-
-// CHECK-LABEL: func @memref_cast_into_tiled_loop(
-func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>)  {
-  %0 = memref.cast %arg0
-    : memref<192xf32> to memref<192xf32, #map>
-  %cst = arith.constant 0.000000e+00 : f32
-  %c24 = arith.constant 24 : index
-  %c0 = arith.constant 0 : index
-  %c192 = arith.constant 192 : index
-  // CHECK: linalg.tiled_loop
-  // CHECK-SAME: outs (%{{.*}} = %{{.*}}: memref<192xf32>)
-  linalg.tiled_loop (%arg3) = (%c0) to (%c192) step (%c24)
-    outs (%out = %0: memref<192xf32, #map>) {
-    %14 = affine.min affine_map<(d0) -> (-d0 + 192, 24)>(%arg3)
-    %16 = memref.subview %out[%arg3] [%14] [1]
-      : memref<192xf32, #map> to memref<?xf32, #map>
-    linalg.fill(%cst, %16) : f32, memref<?xf32, #map>
-    linalg.yield
-  }
-  return
-}
-
-// -----
-
 #accesses = [
   affine_map<(i) -> (i)>
 ]
@@ -368,70 +343,6 @@ func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32
 }
 
 
-// -----
-
-func private @foo(%A: memref<48xf32>, %B: tensor<48xf32>,
-                  %C: memref<48xf32>) -> (tensor<48xf32>)
-
-func @fold_tiled_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>,
-    %C: memref<48xf32>, %C_tensor: tensor<48xf32>) -> tensor<48xf32> {
-  %c0 = arith.constant 0 : index
-  %c24 = arith.constant 24 : index
-  %c48 = arith.constant 48 : index
-  %useful, %useless = linalg.tiled_loop (%i) = (%c0) to (%c48) step (%c24)
-      ins (%A_ = %A: memref<48xf32>)
-      outs (%B_ = %B: tensor<48xf32>,
-            %CT_ = %C_tensor: tensor<48xf32>,
-            %C_ = %C: memref<48xf32>) {
-        %result = call @foo(%A_, %B_, %C_)
-          : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>)
-    linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32>
-  }
-  return %useful : tensor<48xf32>
-}
-
-// CHECK-LABEL: func @fold_tiled_loop_results(
-// CHECK-SAME:   %[[A:.*]]: [[BUF_TY:memref<48xf32>]], %[[B:.*]]: [[TY:tensor<48xf32>]],
-// CHECK-SAME:   %[[C:.*]]: [[BUF_TY]],  %[[C_TENSOR:.*]]: [[TY]]) -> [[TY]] {
-
-// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:  %[[C24:.*]] = arith.constant 24 : index
-// CHECK-DAG:  %[[C48:.*]] = arith.constant 48 : index
-
-// CHECK-NOT: %{{.*}} = linalg.tiled_loop
-// CHECK:  %[[RESULT:.*]] = linalg.tiled_loop (%{{.*}}) = (%[[C0]])
-// CHECK-SAME: to (%[[C48]]) step (%[[C24]])
-// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]])
-// CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) {
-// CHECK-NEXT:   %[[RES:.*]] = call @foo(%[[A_]], %[[B_]], %[[C_]])
-// CHECK-NEXT:   linalg.yield %[[RES]] :
-
-// CHECK: return %[[RESULT]]
-
-// -----
-
-func private @foo(%A: memref<192xf32>, %B: tensor<192xf32>) -> tensor<192xf32>
-
-func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
-                             %B_tensor: tensor<192xf32>) -> tensor<192xf32> {
-  %c0 = arith.constant 0 : index
-  %c24 = arith.constant 24 : index
-  %c192 = arith.constant 192 : index
-  %result = linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
-      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>)
-      outs (%BT_ = %B_tensor: tensor<192xf32>) {
-    %0 = call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32>
-    linalg.yield %0 : tensor<192xf32>
-  }
-  return %result : tensor<192xf32>
-}
-
-// CHECK-LABEL: func @fold_tiled_loop_inputs
-// CHECK: %[[RESULT:.*]] = linalg.tiled_loop
-// CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)
-
-// CHECK: return %[[RESULT]]
-
 // -----
 
 func private @some_use(%i : index, %j : index)
@@ -470,108 +381,6 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
 
 // -----
 
-// CHECK-LABEL: func @dim_of_tiled_loop_input_no_canonicalize(
-//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
-//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
-//       CHECK:   linalg.tiled_loop {{.*}} outs (%[[o:.*]] =
-//       CHECK:     %[[dim:.*]] = tensor.dim %[[o]], %[[c0]]
-//       CHECK:     arith.index_cast %[[dim]]
-func @dim_of_tiled_loop_input_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
-    -> tensor<?x?xf32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
-      to (%d0, %d1) step (%c1, %c1)
-      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
-      outs (%out1 = %arg2 : tensor<?x?xf32>) {
-    %inner_dim = tensor.dim %out1, %c0 : tensor<?x?xf32>
-    %cast1 = arith.index_cast %inner_dim : index to i32
-    %cast2 = arith.sitofp %cast1 : i32 to f32
-    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-    %slice = tensor.extract_slice %fill[0, 0][%s, %s][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-    linalg.yield %slice : tensor<?x?xf32>
-  }
-  return %r : tensor<?x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @dim_of_tiled_loop_input(
-//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
-//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
-//       CHECK:   linalg.tiled_loop
-//       CHECK:     %[[dim:.*]] = tensor.dim %[[arg1]], %[[c0]]
-//       CHECK:     arith.index_cast %[[dim]]
-func @dim_of_tiled_loop_input(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
-    -> tensor<?x?xf32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
-      to (%d0, %d1) step (%c1, %c1)
-      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
-      outs (%out1 = %arg2 : tensor<?x?xf32>) {
-    %inner_dim = tensor.dim %in1, %c0 : tensor<?x?xf32>
-    %cast1 = arith.index_cast %inner_dim : index to i32
-    %cast2 = arith.sitofp %cast1 : i32 to f32
-    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-    linalg.yield %fill : tensor<?x?xf32>
-  }
-  return %r : tensor<?x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @dim_of_tiled_loop_result(
-//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
-//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
-//       CHECK:   tensor.dim %[[arg2]], %[[c0]]
-func @dim_of_tiled_loop_result(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
-    -> index {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
-      to (%d0, %d1) step (%c1, %c1)
-      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
-      outs (%out1 = %arg2 : tensor<?x?xf32>) {
-    %1 = tensor.insert_slice %arg0 into %out1 [0, 0] [%s, %s] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-    linalg.yield %1 : tensor<?x?xf32>
-  }
-  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
-  return %r2 : index
-}
-
-// -----
-
-// CHECK-LABEL: func @dim_of_tiled_loop_result_no_canonicalize(
-//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
-//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
-//       CHECK:   %[[r:.*]] = linalg.tiled_loop
-//       CHECK:   tensor.dim %[[r]], %[[c0]]
-func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
-    -> index {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
-      to (%d0, %d1) step (%c1, %c1)
-      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
-      outs (%out1 = %arg2 : tensor<?x?xf32>) {
-    %1 = tensor.insert_slice %arg0 into %arg1 [0, 0] [%s, %s] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-    linalg.yield %1 : tensor<?x?xf32>
-  }
-  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
-  return %r2 : index
-}
-
-// -----
-
 // CHECK: func @fold_self_copy
 func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 248a966a7a624..2d15bffbb0d64 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -639,7 +639,7 @@ func @scf_for_deps(
     %lb : index,
     %ub : index,
     %step : index)
-  -> (tensor<?xf32>, tensor<?xf32>)
+  -> (tensor<?xf32>)
 {
   // %r0 must be out of place because one use of %t in the subsequent production
   // of %r1 is read.
@@ -666,38 +666,9 @@ func @scf_for_deps(
     scf.yield %t : tensor<?xf32>
   }
 
-  // %r2 must be out of place because one use of %t in the subsequent production
-  // of %r3 is read.
-  //      CHECK: linalg.tiled_loop
-  // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  // CHECK-NEXT: linalg.yield
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  //      CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
-  %r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
-        ins()
-        outs(%t = %B: tensor<?xf32>) {
-    call @some_use(%t) : (tensor<?xf32>) -> ()
-    linalg.yield %t : tensor<?xf32>
-  }
-
-  // %r3 bufferizes inplace fine.
-  //      CHECK: linalg.tiled_loop
-  // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  // CHECK-NEXT: linalg.yield
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  //      CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
-  %r3 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
-        ins()
-        outs(%t = %B: tensor<?xf32>) {
-    call @some_use(%t) : (tensor<?xf32>) -> ()
-    linalg.yield %t : tensor<?xf32>
-  }
-
   //      CHECK: return
-  // CHECK-SAME: __equivalent_func_args__ = [0, 1]
-  return %r1, %r3: tensor<?xf32>, tensor<?xf32>
+  // CHECK-SAME: __equivalent_func_args__ = [0]
+  return %r1: tensor<?xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b76ece00c3123..b1e086c9e7035 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -640,146 +640,6 @@ func private @print_memref_f32(tensor<*xf32>)
 
 // -----
 
-func private @some_use(memref<?xf32>)
-
-#TILE_MAP = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-
-//  CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
-//  CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//  CHECK-DAG: #[[$TILE_MAP:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-
-//      CHECK:  func @tiled_dot(
-// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
-// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
-// CHECK-SAME:    %[[c:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
-func @tiled_dot(
-    %A: tensor<?xf32> {linalg.inplaceable = false},
-    %B: tensor<?xf32> {linalg.inplaceable = false},
-    %c: tensor<f32> {linalg.inplaceable = true},
-    %effecting: memref<?xf32>)
-  -> tensor<f32>
-{
-  %c3 = arith.constant 3 : index
-  %c0 = arith.constant 0 : index
-
-  //     CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$DYN_1D_MAP:.*]]>
-  %0 = tensor.dim %A, %c0 : tensor<?xf32>
-
-  //     CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]]
-  // CHECK-NOT: copy
-  %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
-       ins (%arg4 = %A: tensor<?xf32>, %use = %effecting : memref<?xf32>, %arg5 = %B: tensor<?xf32>)
-      outs (%arg6 = %c: tensor<f32>)
-      iterators["reduction"]
-  {
-    // CHECK-NOT:   alloc
-
-    %2 = tensor.dim %arg4, %c0 : tensor<?xf32>
-    %3 = affine.min #TILE_MAP(%arg3)[%2]
-
-    //     CHECK:   %[[SV_A:.*]] = memref.subview {{.*}}
-    %4 = tensor.extract_slice %arg4[%arg3] [%3] [1] : tensor<?xf32> to tensor<?xf32>
-    %5 = tensor.dim %arg5, %c0 : tensor<?xf32>
-    %6 = affine.min #TILE_MAP(%arg3)[%5]
-
-    //     CHECK:   %[[SV_B:.*]] = memref.subview {{.*}}
-    %7 = tensor.extract_slice %arg5[%arg3] [%6] [1] : tensor<?xf32> to tensor<?xf32>
-
-    //     CHECK:   linalg.dot ins(%[[SV_A]], %[[SV_B]] : memref<?xf32, #[[$DYN_1D_MAP:.*]]>, memref<?xf32, #[[$DYN_1D_MAP:.*]]>) outs(%{{.*}} : memref<f32, #[[$DYN_0D_MAP]]>)
-    %8 = linalg.dot ins(%4, %7 : tensor<?xf32>, tensor<?xf32>) outs(%arg6 : tensor<f32>) -> tensor<f32>
-
-    //     CHECK:   call @some_use(%{{.*}}) : (memref<?xf32>) -> ()
-    call @some_use(%use) : (memref<?xf32>) -> ()
-
-    linalg.yield %8 : tensor<f32>
-    //     CHECK:   linalg.yield
-    // CHECK-NOT:   tensor
-  }
-
-  //     CHECK: return
-  // CHECK-NOT: tensor
-  return %1 : tensor<f32>
-}
-
-// -----
-
-#TILE_MAP = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-
-//  CHECK-DAG: #[[$DYN_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-
-//      CHECK:  func @tiled_fill(
-// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_MAP]]>
-func @tiled_fill(%A: tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32> {
-  %c3 = arith.constant 3 : index
-  %c0 = arith.constant 0 : index
-  %f0 = arith.constant 0.0 : f32
-
-  //     CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$DYN_MAP:.*]]>
-  %0 = tensor.dim %A, %c0 : tensor<?xf32>
-
-  //     CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} outs{{.*}}%[[A]]
-  %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
-      outs (%arg1 = %A: tensor<?xf32>)
-      iterators["parallel"]
-  {
-    // CHECK-NOT:   alloc
-
-    %2 = tensor.dim %arg1, %c0 : tensor<?xf32>
-    %3 = affine.min #TILE_MAP(%arg3)[%2]
-
-    //     CHECK:   %[[SV_A:.*]] = memref.subview {{.*}}
-    %4 = tensor.extract_slice %arg1[%arg3] [%3] [1] : tensor<?xf32> to tensor<?xf32>
-
-    //     CHECK:   linalg.fill(%{{.*}}, %[[SV_A]]) : f32, memref<?xf32, #[[$DYN_MAP:.*]]>
-    %5 = linalg.fill(%f0, %4) : f32, tensor<?xf32> -> tensor<?xf32>
-    %6 = tensor.insert_slice %5 into %arg1[%arg3] [%3] [1] : tensor<?xf32> into tensor<?xf32>
-
-    linalg.yield %6 : tensor<?xf32>
-    //     CHECK:   linalg.yield
-    // CHECK-NOT:   tensor
-  }
-
-  //     CHECK: return
-  // CHECK-NOT: tensor
-  return %1 : tensor<?xf32>
-}
-
-// -----
-
-//      CHECK:  func @tiled_loop_yield_out_of_place(
-// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #{{.*}}>,
-// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #{{.*}}>
-func @tiled_loop_yield_out_of_place(
-    %A: tensor<?xf32> {linalg.inplaceable = true},
-    %B: tensor<?xf32> {linalg.inplaceable = true})
-  -> tensor<?xf32>
-{
-  %c3 = arith.constant 3 : index
-  %c0 = arith.constant 0 : index
-  %f0 = arith.constant 0.0 : f32
-
-  //     CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$DYN_MAP:.*]]>
-  %0 = tensor.dim %A, %c0 : tensor<?xf32>
-
-  //     CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} outs{{.*}}%[[A]]
-  %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
-      outs (%arg1 = %A: tensor<?xf32>)
-      iterators["parallel"]
-  {
-    // CHECK-NOT:   alloc
-    //     CHECK:   memref.copy %[[B]], %[[A]]
-    linalg.yield %B : tensor<?xf32>
-    //     CHECK:   linalg.yield
-    // CHECK-NOT:   tensor
-  }
-
-  //     CHECK: return
-  // CHECK-NOT: tensor
-  return %1 : tensor<?xf32>
-}
-
-// -----
-
 // CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
 // CHECK: func private @external_func(memref<?xf32, #[[$DYNAMIC]]>)

diff  --git a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
deleted file mode 100644
index e7689ac8b339f..0000000000000
--- a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: mlir-opt -test-linalg-distribution %s | FileCheck %s
-
-func private @foo(%A: tensor<64x64xf32>,
-                  %B: tensor<64x64xf32>) -> tensor<64x64xf32>
-
-func @distribute_for_gpu(%A: tensor<64x64xf32>,
-                         %B: tensor<64x64xf32>) -> tensor<64x64xf32> {
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c64 = arith.constant 64 : index
-  %c24 = arith.constant 24 : index
-  %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c64, %c64) step (%c24, %c16)
-      ins (%A_ = %A: tensor<64x64xf32>) outs (%B_ = %B:tensor<64x64xf32>)
-      distribution ["block_x", "block_y"] {
-    %0 = call @foo(%A_, %B_)
-      : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
-    linalg.yield %0 : tensor<64x64xf32>
-  }
-  return %0 : tensor<64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 * 24)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
-
-// CHECK-LABEL: func @distribute_for_gpu
-// CHECK:  %[[C64:.*]] = arith.constant 64 : index
-
-// CHECK-DAG:  %[[GPU_BLOCK_X:.*]] = gpu.block_id x
-// CHECK-DAG:  %[[GPU_GRID_DIM_X:.*]] = gpu.grid_dim x
-// CHECK-DAG:  %[[LB_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_BLOCK_X]]]
-// CHECK-DAG:  %[[STEP_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_GRID_DIM_X]]]
-
-// CHECK-DAG:  %[[GPU_BLOCK_Y:.*]] = gpu.block_id y
-// CHECK-DAG:  %[[GPU_GRID_DIM_Y:.*]] = gpu.grid_dim y
-// CHECK-DAG:  %[[LB_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_BLOCK_Y]]]
-// CHECK-DAG:  %[[STEP_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_GRID_DIM_Y]]]
-
-// CHECK:  linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = (%[[LB_I]], %[[LB_J]])
-// CHECK-SAME: to (%[[C64]], %[[C64]]) step (%[[STEP_I]], %[[STEP_J]])

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index 9eb0e35860f8f..ef315c96f16b4 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse  --split-input-file | FileCheck %s --check-prefix=TLOOP
 
 module {
   func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
@@ -83,64 +82,6 @@ module {
 //      CHECK:   }
 //      CHECK:   return %[[RESULT]]
 
-// TLOOP-LABEL:  func @matmul_fusion(
-// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
-// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
-// TLOOP-SAME: %[[AB_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
-// TLOOP-SAME: %[[C:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
-// TLOOP-SAME: %[[ABC_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-
-// TLOOP-DAG:  %[[C32:.*]] = arith.constant 32 : index
-// TLOOP-DAG:  %[[C64:.*]] = arith.constant 64 : index
-// TLOOP-DAG:  %[[C16:.*]] = arith.constant 16 : index
-// TLOOP-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// TLOOP-DAG:  %[[C1:.*]] = arith.constant 1 : index
-
-// TLOOP:  %[[DIM_A0:.*]] = tensor.dim %[[A]], %[[C0]] : [[TY:.*]]
-
-// TLOOP:  %[[ABC:.*]] = linalg.tiled_loop (%[[IV0:.*]]) = (%[[C0]]) 
-// TLOOP-SAME: to (%[[DIM_A0]]) step (%[[C32]]) 
-// TLOOP-SAME: ins (%[[C_:.*]] = %[[C]]: tensor<?x?xf32>,
-// TLOOP-SAME:      %[[A_:.*]] = %[[A]]: tensor<?x?xf32>,
-// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: tensor<?x?xf32>,
-// TLOOP-SAME:      %[[AB_INIT_:.*]] = %[[AB_INIT]]: tensor<?x?xf32>)
-// TLOOP-SAME: outs (%[[ABC_INIT_:.*]] = %[[ABC_INIT]]: tensor<?x?xf32>) {
-
-// TLOOP:    %[[ABC_INIT_SUB:.*]] = tensor.extract_slice %[[ABC_INIT_]][%[[IV0]], 0]
-// TLOOP:    %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[IV0]], 0]
-// TLOOP:    %[[AB_INIT_SUB:.*]] = tensor.extract_slice %[[AB_INIT_]][%[[IV0]], 0]
-
-// TLOOP:    %[[AB_SUB:.*]] = linalg.matmul
-// TLOOP-SAME:  ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
-
-// TLOOP:    %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]]
-// TLOOP:    %[[DIM_C_1:.*]] = tensor.dim %[[C]], %[[C1]] : [[TY]]
-
-// TLOOP:    %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) = 
-// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
-// TLOOP-SAME: step (%[[C64]], %[[C16]]) 
-// TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]],
-// TLOOP-SAME:      %[[C__:.*]] = %[[C_]]: [[TY]])
-// TLOOP-SAME: outs (%[[ABC_INIT_SUB_:.*]] = %[[ABC_INIT_SUB]]: [[TY]])
-// TLOOP-SAME: iterators["parallel", "reduction"] {
-
-// TLOOP:      %[[AB_SUB_SUB:.*]] = tensor.extract_slice %[[AB_SUB_]][0, %[[IV2]]]
-// TLOOP:      %[[C__SUB:.*]] = tensor.extract_slice %[[C__]][%[[IV2]], %[[IV1]]]
-// TLOOP:      %[[ABS_INIT_SUB_SUB:.*]] = tensor.extract_slice %[[ABC_INIT_SUB_]][0, %[[IV1]]]
-
-// TLOOP:      %[[ABC_SUB_SUB:.*]] = linalg.matmul
-// TLOOP-SAME:  ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME:  outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
-
-// TLOOP:      %[[RES0:.*]] = tensor.insert_slice %[[ABC_SUB_SUB]]
-// TLOOP-SAME:   into %[[ABC_INIT_SUB_]][0, %[[IV1]]]
-// TLOOP:      linalg.yield %[[RES0]] : [[TY]]
-// TLOOP:    }
-// TLOOP:    %[[RES1:.*]] = tensor.insert_slice %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0]
-// TLOOP:    linalg.yield %[[RES1]] : [[TY]]
-// TLOOP:  }
-// TLOOP:  return %[[ABC]] : [[TY]]
-
 // -----
 
 module {
@@ -195,48 +136,6 @@ module {
 //       CHECK:     scf.yield %[[YIELD]]
 //       CHECK:   return %[[RESULT]]
 
-// TLOOP-LABEL: func @matmul_plus_matmul
-// TLOOP-SAME:    %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
-// TLOOP-SAME:    %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
-// TLOOP-SAME:    %[[AB:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-
-// TLOOP-DAG:  %[[C32:.*]] = arith.constant 32 : index
-// TLOOP-DAG:  %[[C64:.*]] = arith.constant 64 : index
-// TLOOP-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// TLOOP-DAG:  %[[C1:.*]] = arith.constant 1 : index
-
-// TLOOP:  %[[DIM_A_0:.*]] = tensor.dim %[[A]], %[[C0]] : [[TY:.*]]
-// TLOOP:  %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]]
-
-// TLOOP:  %[[INIT:.*]] = linalg.init_tensor [%[[DIM_A_0]], %[[DIM_B_1]]]
-
-// TLOOP:  %[[RESULT:.*]] = linalg.tiled_loop (%[[IV0:.*]], %[[IV1:.*]]) =
-// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
-// TLOOP-SAME: step (%[[C32]], %[[C64]])
-// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
-// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]],
-// TLOOP-SAME:      %[[AB_:.*]] = %[[AB]]: [[TY]])
-// TLOOP-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: [[TY]]) {
-
-// TLOOP:    %[[INIT_SUB:.*]] = tensor.extract_slice %[[INIT_]][%[[IV0]], %[[IV1]]]
-// TLOOP:    %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[IV0]], 0]
-// TLOOP:    %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[IV1]]]
-// TLOOP:    %[[AB_SUB_INIT:.*]] = tensor.extract_slice %[[AB_]][%[[IV0]], %[[IV1]]]
-
-// TLOOP:    %[[AB_SUB:.*]] = linalg.matmul
-// TLOOP-SAME:  ins(%[[A_SUB]], %[[B_SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME:  outs(%[[AB_SUB_INIT]] : [[TY]])
-
-// TLOOP:    %[[DOUBLE_AB:.*]] = linalg.generic
-// TLOOP-SAME:  ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]])
-
-// TLOOP:    %[[RESULT_SUB:.*]] = tensor.insert_slice
-// TLOOP-SAME:  %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]]
-
-// TLOOP:    linalg.yield %[[RESULT_SUB]] : [[TY]]
-// TLOOP:  }
-// TLOOP:  return %[[RESULT]] : [[TY]]
-
 // -----
 
 module {
@@ -270,59 +169,6 @@ module {
 //       CHECK:     %[[MM:.*]] = tensor.insert_slice %[[ST_MM_RES]] into {{.*}}
 //       CHECK:     scf.yield %[[MM]] : tensor<?x?xf32>
 
-
-// TLOOP-LABEL: func @matmul_out_fusion(
-// TLOOP-SAME:    %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// TLOOP-SAME:    %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// TLOOP-SAME:    %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-
-// TLOOP-DAG:  %[[C0_F32:.*]] = arith.constant 0.0
-// TLOOP-DAG:  %[[C32:.*]] = arith.constant 32 : index
-// TLOOP-DAG:  %[[C64:.*]] = arith.constant 64 : index
-// TLOOP-DAG:  %[[C16:.*]] = arith.constant 16 : index
-// TLOOP-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// TLOOP-DAG:  %[[C1:.*]] = arith.constant 1 : index
-
-// TLOOP:  %[[DIM_A_0:.*]] = tensor.dim %[[A]], %[[C0]] : [[TY:.*]]
-// TLOOP:  %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]]
-
-// TLOOP:  %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
-// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
-// TLOOP-SAME: step (%[[C32]], %[[C64]])
-// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
-// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]],
-// TLOOP-SAME:      %[[C0_F32_:.*]] = %[[C0_F32]]
-// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
-
-// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
-// TLOOP:    %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
-// TLOOP:    %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
-// TLOOP:    %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
-// TLOOP:    %[[INIT_SUB:.*]] = linalg.fill(%[[C0_F32_]], %[[OUT_SUB]])
-
-// TLOOP:    %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
-// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
-// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
-// TLOOP-SAME:      %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
-// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
-// TLOOP-SAME: iterators["reduction"] {
-
-// TLOOP:      %[[A_SUB_SUB:.*]] = tensor.extract_slice %[[A_SUB_]][0, %[[K]]]
-// TLOOP:      %[[B_SUB_SUB:.*]] = tensor.extract_slice %[[B_SUB_]][%[[K]], 0]
-// TLOOP:      %[[INIT_SUB_SUB:.*]] = tensor.extract_slice %[[INIT_SUB_]][0, 0]
-
-// TLOOP:      %[[AB_SUB_SUB:.*]] = linalg.matmul
-// TLOOP-SAME:   ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME:   outs(%[[INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
-// TLOOP:      %[[AB_SUB_:.*]] = tensor.insert_slice %[[AB_SUB_SUB]] into %[[INIT_SUB_]]
-// TLOOP:      linalg.yield %[[AB_SUB_]] : [[TY]]
-// TLOOP:    }
-// TLOOP:    %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]]
-// TLOOP-SAME:  into %[[OUT_]][%[[I]], %[[J]]]
-// TLOOP:    linalg.yield %[[SUB_RESULT]] : [[TY]]
-// TLOOP:  }
-// TLOOP:  return %[[AB]] : [[TY]]
-
 // -----
 
 module {
@@ -343,58 +189,3 @@ module {
     return %1 : tensor<?x?xf32>
   }
 }
-
-// TLOOP-LABEL: func @generic_plus_matmul(
-// TLOOP-SAME:    %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// TLOOP-SAME:    %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// TLOOP-SAME:    %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-
-// TLOOP-DAG:  %[[C0_F32:.*]] = arith.constant 0.0
-// TLOOP-DAG:  %[[C32:.*]] = arith.constant 32 : index
-// TLOOP-DAG:  %[[C64:.*]] = arith.constant 64 : index
-// TLOOP-DAG:  %[[C16:.*]] = arith.constant 16 : index
-// TLOOP-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// TLOOP-DAG:  %[[C1:.*]] = arith.constant 1 : index
-
-// TLOOP:  %[[DIM_A_0:.*]] = tensor.dim %[[A]], %[[C0]] : [[TY:.*]]
-// TLOOP:  %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]]
-
-// TLOOP:  %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
-// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
-// TLOOP-SAME: step (%[[C32]], %[[C64]])
-// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
-// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]],
-// TLOOP-SAME:      %[[C0_F32_:.*]] = %[[C0_F32]]
-// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
-
-// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
-// TLOOP:    %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
-// TLOOP:    %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
-// TLOOP:    %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
-// TLOOP:    %[[INIT_SUB:.*]] = linalg.generic
-// TLOOP-SAME: ins(%[[C0_F32_]]
-// TLOOP-SAME: outs(%[[OUT_SUB]]
-
-// TLOOP:    %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
-// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
-// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
-// TLOOP-SAME:      %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
-// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
-// TLOOP-SAME: iterators["reduction"] {
-
-// TLOOP:      %[[A_SUB_SUB:.*]] = tensor.extract_slice %[[A_SUB_]][0, %[[K]]]
-// TLOOP:      %[[B_SUB_SUB:.*]] = tensor.extract_slice %[[B_SUB_]][%[[K]], 0]
-// TLOOP:      %[[INIT_SUB_SUB:.*]] = tensor.extract_slice %[[INIT_SUB_]][0, 0]
-
-// TLOOP:      %[[AB_SUB_SUB:.*]] = linalg.matmul
-// TLOOP-SAME:   ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME:   outs(%[[INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
-// TLOOP:      %[[AB_SUB_:.*]] = tensor.insert_slice %[[AB_SUB_SUB]] into %[[INIT_SUB_]]
-// TLOOP:      linalg.yield %[[AB_SUB_]] : [[TY]]
-// TLOOP:    }
-// TLOOP:    %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]]
-// TLOOP-SAME:  into %[[OUT_]][%[[I]], %[[J]]]
-// TLOOP:    linalg.yield %[[SUB_RESULT]] : [[TY]]
-// TLOOP:  }
-// TLOOP:  return %[[AB]] : [[TY]]
-

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c45c58abf91c8..081df97b7a0fc 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -411,110 +411,6 @@ func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2
 
 // -----
 
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-
-func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                  %C: memref<192x192xf32>) -> ()
-
-func @tiled_loop_incorrent_num_yield_operands(%A: memref<192x192xf32>,
-    %B: memref<192x192xf32>, %C: memref<192x192xf32>,
-    %C_tensor: tensor<192x192xf32>) {
-  %c24 = arith.constant 24 : index
-  %c0 = arith.constant 0 : index
-  %c192 = arith.constant 192 : index
-  %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
-      step (%c24, %c24)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%CT_ = %C_tensor: tensor<192x192xf32>,
-            %C_ = %C: memref<192x192xf32>) {
-        call @foo(%A_, %B_, %C_)
-          : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
-    // expected-error @+1 {{expected number of tensor output args = 1 to match the number of yield operands = 0}}
-    linalg.yield
-  }
-  return
-}
-
-// -----
-
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-
-func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                  %C: memref<192x192xf32>) -> tensor<f32>
-
-func @tiled_loop_incorrent_yield_operand_type(%A: memref<192x192xf32>,
-    %B: memref<192x192xf32>, %C: memref<192x192xf32>,
-    %C_tensor: tensor<192x192xf32>) {
-  %c24 = arith.constant 24 : index
-  %c0 = arith.constant 0 : index
-  %c192 = arith.constant 192 : index
-  %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
-      step (%c24, %c24)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%CT_ = %C_tensor: tensor<192x192xf32>,
-            %C_ = %C: memref<192x192xf32>) {
-        %1 = call @foo(%A_, %B_, %C_)
-          : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> tensor<f32>
-    // expected-error @+1 {{expected yield operand 0 with type = 'tensor<f32>' to match output arg type = 'tensor<192x192xf32>}}
-    linalg.yield %1 : tensor<f32>
-  }
-  return
-}
-
-// -----
-
-func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                  %C: memref<192x192xf32>) -> ()
-
-func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
-    %B: memref<192x192xf32>, %C: memref<192x192xf32>,
-    %C_tensor: tensor<192x192xf32>) {
-  %c24 = arith.constant 24 : index
-  %c0 = arith.constant 0 : index
-  %c192 = arith.constant 192 : index
-  // expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}}
-  %0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ( {
-    ^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>,
-         %B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>,
-         %C_: memref<192x192xf32>):
-      call @foo(%A_, %B_, %C_)
-          : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
-      linalg.yield %CT_ : tensor<192x192xf32>
-    }) {
-      iterator_types = ["parallel"],
-      operand_segment_sizes = dense<2> : vector<5xi32>
-    } : (index, index, index, index, index, index, memref<192x192xf32>,
-      memref<192x192xf32>, tensor<192x192xf32>, memref<192x192xf32>
-    ) -> tensor<192x192xf32>
-  return
-}
-
-// -----
-
-func private @foo(%A: memref<100xf32>) -> ()
-
-func @tiled_loop_incorrent_block_arg_type(%A: memref<192xf32>) {
-  %c0 = arith.constant 0 : index
-  %c192 = arith.constant 192 : index
-  %c24 = arith.constant 24 : index
-  // expected-error @+1 {{expected output arg 0 with type = 'memref<192xf32>' to match region arg 1 type = 'memref<100xf32>'}}
-  "linalg.tiled_loop"(%c0, %c192, %c24, %A) ( {
-    ^bb0(%arg4: index, %A_: memref<100xf32>):
-      call @foo(%A_) : (memref<100xf32>)-> ()
-      linalg.yield
-    }) {
-      iterator_types = ["parallel"],
-      operand_segment_sizes = dense<[1, 1, 1, 0, 1]> : vector<5xi32>
-    } : (index, index, index, memref<192xf32>) -> ()
-  return
-}
-
-// -----
-
 #attrs = {
         indexing_maps = [
                 affine_map<(i) -> (3 - i)>,

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 0c70868cbc32c..f2957f73e6221 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -6,8 +6,6 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
 
-// CHECK-DAG: #[[$id_2d:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$id_1d:.*]] = affine_map<(d0, d1, d2) -> (d1)>
 // CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
@@ -155,7 +153,7 @@ func @generic_without_inputs(%arg0 : memref<?x?x?xf32>) {
   linalg.generic  {indexing_maps = [#map0],
                    iterator_types = ["parallel", "parallel", "parallel"]}
                   outs(%arg0 : memref<?x?x?xf32>) {
-   ^bb0(%arg3: f32):  
+   ^bb0(%arg3: f32):
       %cst = arith.constant 0.000000e+00 : f32
       linalg.yield %cst : f32
     }
@@ -218,7 +216,7 @@ func @generic_with_multiple_tensor_outputs(
     iterator_types = ["reduction"]}
     ins(%arg0, %arg1 : tensor<?xi32>, tensor<?xi32>)
     outs(%1, %3 : tensor<i32>, tensor<i32>) {
-  ^bb0(%arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32):  
+  ^bb0(%arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32):
     %5 = arith.cmpi sge, %arg3, %arg5 : i32
     %6 = arith.select %5, %arg3, %arg5 : i32
     %7 = arith.cmpi eq, %arg3, %arg5 : i32
@@ -352,173 +350,3 @@ func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-
-// -----
-
-#accesses_4 = [
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (i, j)>
-]
-
-#trait_4 = {
-  indexing_maps = #accesses_4,
-  iterator_types = ["parallel", "parallel"]
-}
-
-func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>,
-                 %out: tensor<24x64xi8>) -> tensor<24x64xi8> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
- %c24 = arith.constant 24 : index
- %c64 = arith.constant 64 : index
- %prod = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
-      ins(%lhs_ = %lhs: tensor<24x64xi8>, %rhs_ = %rhs: tensor<24x64xi8>)
-      outs(%out_ = %out: tensor<24x64xi8>) {
-    %lhs_sub = tensor.extract_slice %lhs_[%i, 0] [%c4, %c64] [1, 1]
-        : tensor<24x64xi8> to tensor<?x?xi8>
-    %rhs_sub = tensor.extract_slice %rhs_[%i, 0] [%c4, %c64] [1, 1]
-        : tensor<24x64xi8> to tensor<?x?xi8>
-    %out_sub = tensor.extract_slice %out_[%i, 0] [%c4, %c64] [1, 1]
-        : tensor<24x64xi8> to tensor<?x?xi8>
-
-    %sum = linalg.generic #trait_4
-        ins(%lhs_sub, %rhs_sub : tensor<?x?xi8>, tensor<?x?xi8>)
-        outs(%out_sub : tensor<?x?xi8>) {
-      ^bb(%l: i8, %r: i8, %o: i8) :
-        %s = arith.addi %l, %r : i8
-        linalg.yield %s : i8
-      } -> tensor<?x?xi8>
-
-    %sum_sub = tensor.insert_slice %sum into %out_[%i, 0][%c4, %c64][1, 1]
-      : tensor<?x?xi8> into tensor<24x64xi8>
-    linalg.yield %sum_sub : tensor<24x64xi8>
-  }
-  return %prod : tensor<24x64xi8>
-}
-// CHECK-LABEL: func @tiled_loop
-// CHECK-NOT: iterators[
-
-// -----
-
-#id_3d = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#id_2d = affine_map<(d0, d1, d2) -> (d0, d2)>
-#id_1d = affine_map<(d0, d1, d2) -> (d1)>
-
-#trait_5 = {
-  indexing_maps = [
-    #id_3d,
-    #id_2d,
-    #id_1d,
-    #id_1d
-  ],
-  iterator_types = ["reduction", "parallel", "reduction"]
-}
-
-func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
-                           %input_2d: tensor<16x32xf32>,
-                           %input_1d: tensor<24xf32>,
-                           %output: tensor<24xf32>) -> tensor<24xf32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c4 = arith.constant 4 : index
-  %c8 = arith.constant 8 : index
-  %X = tensor.dim %input_3d, %c0 : tensor<16x24x32xf32>
-  %Y = tensor.dim %input_3d, %c1 : tensor<16x24x32xf32>
-  %Z = tensor.dim %input_3d, %c2 : tensor<16x24x32xf32>
-  %result = linalg.tiled_loop (%i, %j, %k)
-      = (%c0, %c0, %c0) to (%X, %Y, %Z) step (%c2, %c4, %c8)
-      ins(%i3d_ = %input_3d: tensor<16x24x32xf32>,
-          %i2d_ = %input_2d: tensor<16x32xf32>,
-          %i1d_ = %input_1d: tensor<24xf32>)
-      outs(%o_ =  %output: tensor<24xf32>)
-      iterators["reduction", "parallel", "reduction"]
-      distribution["block_x", "block_y", "none"] {
-    %sub_3d = tensor.extract_slice %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1]
-      : tensor<16x24x32xf32> to tensor<2x4x8xf32>
-    %sub_2d = tensor.extract_slice %i2d_[%i, %k][2, 8][1, 1]
-      : tensor<16x32xf32> to tensor<2x8xf32>
-    %sub_1d = tensor.extract_slice %i1d_[%j] [4] [1]
-      : tensor<24xf32> to tensor<4xf32>
-    %sub_out = tensor.extract_slice %o_[%j] [4] [1]
-      : tensor<24xf32> to tensor<4xf32>
-    %acc = linalg.generic #trait_5
-      ins(%sub_3d, %sub_2d, %sub_1d
-        : tensor<2x4x8xf32>, tensor<2x8xf32>, tensor<4xf32>)
-      outs(%sub_out : tensor<4xf32>)  {
-    ^bb0(%i3d: f32, %i2d: f32, %i1d: f32, %o: f32):
-      %0 = arith.addf %i3d, %i2d : f32
-      %1 = arith.addf %0, %i1d : f32
-      linalg.yield %1 : f32
-    } -> tensor<4xf32>
-
-    %sum_sub = tensor.insert_slice %acc into %o_[%j][4][1]
-      : tensor<4xf32> into tensor<24xf32>
-    linalg.yield %sum_sub : tensor<24xf32>
-  }
-  return %result : tensor<24xf32>
-}
-// CHECK-LABEL: func @tiled_loop_reduction
-// CHECK: iterators[
-
-// -----
-
-#trait_6 = {
-  indexing_maps = [
-    #id_3d,
-    #id_2d,
-    #id_1d,
-    #id_1d
-  ],
-  iterator_types = ["reduction", "parallel", "reduction"]
-}
-#map_1 = affine_map<(d0, d1, d2)[s0] -> (d0 * 768 + s0 + d1 * 32 + d2)>
-#map_2 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
-#map_3 = affine_map<(d0)[s0] -> (d0 + s0)>
-
-func @tiled_loop_on_buffers(%input_3d: memref<16x24x32xf32>,
-                            %input_2d: memref<16x32xf32>,
-                            %input_1d: memref<24xf32>,
-                            %output: memref<24xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c4 = arith.constant 4 : index
-  %c8 = arith.constant 8 : index
-  %X = memref.dim %input_3d, %c0 : memref<16x24x32xf32>
-  %Y = memref.dim %input_3d, %c1 : memref<16x24x32xf32>
-  %Z = memref.dim %input_3d, %c2 : memref<16x24x32xf32>
-  linalg.tiled_loop (%i, %j, %k) = (%c0, %c0, %c0)
-      to (%X, %Y, %Z) step (%c2, %c4, %c8)
-      ins(%i3d_ = %input_3d: memref<16x24x32xf32>,
-          %i2d_ = %input_2d: memref<16x32xf32>,
-          %i1d_ = %input_1d: memref<24xf32>)
-      outs(%o_ =  %output: memref<24xf32>)
-      iterators["reduction", "parallel", "reduction"] {
-    %sub_3d = memref.subview %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1]
-      : memref<16x24x32xf32> to memref<2x4x8xf32, #map_1>
-    %sub_2d = memref.subview %i2d_[%i, %k][2, 8][1, 1]
-      : memref<16x32xf32> to memref<2x8xf32, #map_2>
-    %sub_1d = memref.subview %i1d_[%j] [4] [1]
-      : memref<24xf32> to memref<4xf32, #map_3>
-    %sub_out = memref.subview %o_[%j] [4] [1]
-      : memref<24xf32> to memref<4xf32, #map_3>
-    linalg.generic #trait_6
-      ins(%sub_3d, %sub_2d, %sub_1d
-        : memref<2x4x8xf32, #map_1>,
-          memref<2x8xf32, #map_2>,
-          memref<4xf32, #map_3>)
-      outs(%sub_out : memref<4xf32, #map_3>)  {
-    ^bb0(%i3d: f32, %i2d: f32, %i1d: f32, %o: f32):
-      %0 = arith.addf %i3d, %i2d : f32
-      %1 = arith.addf %0, %i1d : f32
-      linalg.yield %1 : f32
-    }
-    linalg.yield
-  }
-  return
-}
-// CHECK-LABEL: func @tiled_loop_on_buffers
-// CHECK: iterators[

diff  --git a/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir
index b18b5044246a3..e1b2def4aadee 100644
--- a/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir
@@ -4,12 +4,6 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern tile-sizes=256,128,512 peeled-loops=1,2" -canonicalize | \
 // RUN:     FileCheck %s -check-prefix=CHECK-PEEL-12
 
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern tile-sizes=256,128,512 loop-type=tiled_loop peeled-loops=0" -canonicalize | \
-// RUN:     FileCheck %s -check-prefix=CHECK-TILED-LOOP-PEEL-0
-
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern tile-sizes=256,128,512 loop-type=tiled_loop peeled-loops=0,1" -canonicalize | \
-// RUN:     FileCheck %s -check-prefix=CHECK-TILED-LOOP-PEEL-01
-
 //     CHECK-PEEL-0: func @matmul_static_tensor
 // CHECK-PEEL-0-DAG:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-PEEL-0-DAG:   %[[c128:.*]] = arith.constant 128 : index
@@ -51,42 +45,6 @@
 //     CHECK-PEEL-12:       linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x36xf32>) outs({{.*}} : tensor<?x36xf32>)
 //     CHECK-PEEL-12:     }
 //     CHECK-PEEL-12:   }
-
-//     CHECK-TILED-LOOP-PEEL-0: func @matmul_static_tensor
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c128:.*]] = arith.constant 128 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c256:.*]] = arith.constant 256 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c512:.*]] = arith.constant 512 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c1280:.*]] = arith.constant 1280 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c1500:.*]] = arith.constant 1500 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c1600:.*]] = arith.constant 1600 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c1700:.*]] = arith.constant 1700 : index
-//     CHECK-TILED-LOOP-PEEL-0:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[c1280]], %[[c1700]], %[[c1600]]) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-0:     linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<256x?xf32>)
-//     CHECK-TILED-LOOP-PEEL-0:   }
-//     CHECK-TILED-LOOP-PEEL-0:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%[[c1280]], %[[c0]], %[[c0]]) to (%[[c1500]], %[[c1700]], %[[c1600]]) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-0:     linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
-//     CHECK-TILED-LOOP-PEEL-0:   }
-
-//     CHECK-TILED-LOOP-PEEL-01: func @matmul_static_tensor
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c128:.*]] = arith.constant 128 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c256:.*]] = arith.constant 256 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c512:.*]] = arith.constant 512 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c1280:.*]] = arith.constant 1280 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c1500:.*]] = arith.constant 1500 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c1600:.*]] = arith.constant 1600 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c1664:.*]] = arith.constant 1664 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c1700:.*]] = arith.constant 1700 : index
-//     CHECK-TILED-LOOP-PEEL-01:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[c1280]], %[[c1664]], %[[c1600]]) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-01:     linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x128xf32>) outs({{.*}} : tensor<256x128xf32>)
-//     CHECK-TILED-LOOP-PEEL-01:   }
-//     CHECK-TILED-LOOP-PEEL-01:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%[[c0]], %[[c1664]], %[[c0]]) to (%[[c1280]], %[[c1700]], %[[c1600]]) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-01:     linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<256x?xf32>)
-//     CHECK-TILED-LOOP-PEEL-01:   }
-//     CHECK-TILED-LOOP-PEEL-01:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%[[c1280]], %[[c0]], %[[c0]]) to (%[[c1500]], %[[c1700]], %[[c1600]]) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-01:     linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
-//     CHECK-TILED-LOOP-PEEL-01:   }
 func @matmul_static_tensor(%arg0: tensor<1500x1600xf32>, %arg1: tensor<1600x1700xf32>)
     -> tensor<1500x1700xf32> {
   %out = linalg.init_tensor [1500, 1700] : tensor<1500x1700xf32>
@@ -138,33 +96,6 @@ func @matmul_static_tensor(%arg0: tensor<1500x1600xf32>, %arg1: tensor<1600x1700
 //     CHECK-PEEL-12:       }
 //     CHECK-PEEL-12:     }
 //     CHECK-PEEL-12:   }
-
-//     CHECK-TILED-LOOP-PEEL-0: func @matmul_dynamic_tensor
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c128:.*]] = arith.constant 128 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c256:.*]] = arith.constant 256 : index
-// CHECK-TILED-LOOP-PEEL-0-DAG:   %[[c512:.*]] = arith.constant 512 : index
-//     CHECK-TILED-LOOP-PEEL-0:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%[[c0]], %[[c0]], %[[c0]]) to (%{{.*}}, %{{.*}}, %{{.*}}) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-0:     linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<256x?xf32>)
-//     CHECK-TILED-LOOP-PEEL-0:   }
-//     CHECK-TILED-LOOP-PEEL-0:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%{{.*}}, %[[c0]], %[[c0]]) to (%{{.*}}, %{{.*}}, %{{.*}}) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-0:     linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
-//     CHECK-TILED-LOOP-PEEL-0:   }
-
-//     CHECK-TILED-LOOP-PEEL-01: func @matmul_dynamic_tensor
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c128:.*]] = arith.constant 128 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c256:.*]] = arith.constant 256 : index
-// CHECK-TILED-LOOP-PEEL-01-DAG:   %[[c512:.*]] = arith.constant 512 : index
-//     CHECK-TILED-LOOP-PEEL-01:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%[[c0]], %[[c0]], %[[c0]]) to (%{{.*}}, %{{.*}}, %{{.*}}) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-01:     linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x128xf32>) outs({{.*}} : tensor<256x128xf32>)
-//     CHECK-TILED-LOOP-PEEL-01:   }
-//     CHECK-TILED-LOOP-PEEL-01:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%[[c0]], %{{.*}}, %[[c0]]) to (%{{.*}}, %{{.*}}, %{{.*}}) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-01:     linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<256x?xf32>)
-//     CHECK-TILED-LOOP-PEEL-01:   }
-//     CHECK-TILED-LOOP-PEEL-01:   linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) = (%{{.*}}, %[[c0]], %[[c0]]) to (%{{.*}}, %{{.*}}, %{{.*}}) step (%[[c256]], %[[c128]], %[[c512]])
-//     CHECK-TILED-LOOP-PEEL-01:     linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
-//     CHECK-TILED-LOOP-PEEL-01:   }
 func @matmul_dynamic_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
     -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 3e56c21b049c3..a1a65fa289104 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3,4" -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3,4 loop-type=tiled_loop distribution-types=block_x,block_y,none" -split-input-file | FileCheck %s -check-prefix=TLOOP
 
 // CHECK-LABEL: func @matmul_tensors(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
@@ -28,39 +27,6 @@ func @matmul_tensors(
   return %0 : tensor<?x?xf32>
 }
 
-// TLOOP-LABEL: func @matmul_tensors
-// TLOOP-SAME: (%[[ARG_0:.*]]: [[TY:.*]], %[[ARG_1:.*]]: [[TY]],
-// TLOOP-SAME: %[[ARG_2:.*]]: [[TY]]) -> [[TY]] {
-
-// TLOOP-DAG: %[[C0:.*]] = arith.constant 0 : index
-// TLOOP-DAG: %[[C1:.*]] = arith.constant 1 : index
-// TLOOP-DAG: %[[C2:.*]] = arith.constant 2 : index
-// TLOOP-DAG: %[[C3:.*]] = arith.constant 3 : index
-// TLOOP-DAG: %[[C4:.*]] = arith.constant 4 : index
-
-// TLOOP: %[[ARG_0_X:.*]] = tensor.dim %[[ARG_0]], %[[C0]] : [[TY]]
-// TLOOP: %[[ARG_0_Y:.*]] = tensor.dim %[[ARG_0]], %[[C1]] : [[TY]]
-// TLOOP: %[[ARG_1_Y:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : [[TY]]
-
-// TLOOP: %{{.*}} = linalg.tiled_loop (%[[I:.*]], %[[J:.*]], %[[K:.*]]) =
-// TLOOP-SAME: (%[[C0]], %[[C0]], %[[C0]])
-// TLOOP-SAME: to (%[[ARG_0_X]], %[[ARG_1_Y]], %[[ARG_0_Y]])
-// TLOOP-SAME: step (%[[C2]], %[[C3]], %[[C4]])
-// TLOOP-SAME: ins (%[[A0:.*]] = %[[ARG_0]]: [[TY]], %[[A1:.*]] = %[[ARG_1]]: [[TY]])
-// TLOOP-SAME: outs (%[[A2:.*]] = %[[ARG_2]]: [[TY]])
-// TLOOP-SAME: iterators["parallel", "parallel", "reduction"]
-// TLOOP-SAME: distribution["block_x", "block_y", "none"] {
-
-// TLOOP: %[[SUB_ARG_0:.*]] = tensor.extract_slice %[[A0]][%[[I]], %[[K]]]
-// TLOOP: %[[SUB_ARG_1:.*]] = tensor.extract_slice %[[A1]][%[[K]], %[[J]]]
-// TLOOP: %[[SUB_ARG_2:.*]] = tensor.extract_slice %[[A2]][%[[I]], %[[J]]]
-
-// TLOOP: %[[PROD:.*]] = linalg.matmul ins(%[[SUB_ARG_0]], %[[SUB_ARG_1]]
-// TLOOP-SE: outs(%[[SUB_ARG_2]] : [[TY]]) -> [[TY]]
-
-// TLOOP: %[[O:.*]] = tensor.insert_slice %[[PROD]] into %[[A2]][%[[I]], %[[J]]]
-// TLOOP: linalg.yield %[[O]] : [[TY]]
-
 // -----
 
 func @generic_op_tensors(
@@ -108,29 +74,6 @@ func @generic_op_tensors(
 //       CHECK: }
 //       CHECK: return %[[TD0]]
 
-// TLOOP-LABEL: func @generic_op_tensors(
-// TLOOP-SAME:    %[[ARG_0:.*]]: [[TY:.*]],
-// TLOOP-SAME:    %[[ARG_1:.*]]: [[TY]]) -> [[TY]] {
-
-// TLOOP-DAG: %[[C0:.*]] = arith.constant 0 : index
-// TLOOP-DAG: %[[C1:.*]] = arith.constant 1 : index
-// TLOOP-DAG: %[[C2:.*]] = arith.constant 2 : index
-// TLOOP-DAG: %[[C3:.*]] = arith.constant 3 : index
-// TLOOP-DAG: %[[C4:.*]] = arith.constant 4 : index
-
-// TLOOP:     %[[INIT:.*]] = linalg.init_tensor
-// TLOOP:     %[[ARG_0_X:.*]] = tensor.dim %[[ARG_0]], %[[C0]] : [[TY]]
-// TLOOP:     %[[ARG_0_Y:.*]] = tensor.dim %[[ARG_0]], %[[C1]] : [[TY]]
-// TLOOP:     %[[ARG_0_Z:.*]] = tensor.dim %[[ARG_0]], %[[C2]] : [[TY]]
-
-// TLOOP:     %{{.*}} = linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) =
-// TLOOP-SAME: (%[[C0]], %[[C0]], %[[C0]])
-// TLOOP-SAME: to (%[[ARG_0_X]], %[[ARG_0_Y]], %[[ARG_0_Z]])
-// TLOOP-SAME: step (%[[C2]], %[[C3]], %[[C4]])
-// TLOOP-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]])
-// TLOOP-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]])
-// TLOOP-SAME: distribution["block_x", "block_y", "none"] {
-
 // -----
 
 //  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>

diff  --git a/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir b/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir
deleted file mode 100644
index 106fcee1b130e..0000000000000
--- a/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir
+++ /dev/null
@@ -1,231 +0,0 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-tiled-loop-peeling=2 -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-2
-// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-tiled-loop-peeling=0,1,2 -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012
-// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns="test-tiled-loop-peeling=0,1,2 skip-partial" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012-SKIP-PARTIAL
-
-// CHECK-TILE-2-LABEL: func @tiled_loop_3d_tensor(
-//  CHECK-TILE-2-SAME:     %[[input:.*]]: tensor<?x?x?xf32>, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-TILE-2-DAG:   %[[c0:.*]] = arith.constant 0 : index
-//   CHECK-TILE-2-DAG:   %[[c1:.*]] = arith.constant 1 : index
-//   CHECK-TILE-2-DAG:   %[[c2:.*]] = arith.constant 2 : index
-//       CHECK-TILE-2:   %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]]
-//       CHECK-TILE-2:   %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]]
-//       CHECK-TILE-2:   %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]]
-//       CHECK-TILE-2:   %[[init_tensor:.*]] = linalg.init_tensor
-//       CHECK-TILE-2:   %[[split_bound:.*]] = affine.apply
-//       CHECK-TILE-2:   %[[r1:.*]] = linalg.tiled_loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]])
-//  CHECK-TILE-2-SAME:       to (%[[dim0]], %[[dim1]], %[[split_bound]])
-//  CHECK-TILE-2-SAME:       step (%[[s0]], %[[s1]], %[[s2]])
-//  CHECK-TILE-2-SAME:       ins (%[[loop_in1:.*]] = %[[input]]: tensor<?x?x?xf32>)
-//  CHECK-TILE-2-SAME:       outs (%[[loop_out1:.*]] = %[[init_tensor]]: tensor<?x?x?xf32>) {
-//       CHECK-TILE-2:     %[[min0_1:.*]] = affine.min
-//       CHECK-TILE-2:     %[[min1_1:.*]] = affine.min
-//       CHECK-TILE-2:     %[[in_slice1:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
-//       CHECK-TILE-2:     %[[out_slice1:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
-//       CHECK-TILE-2:     %[[mod_slice1:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
-//       CHECK-TILE-2:     linalg.yield %[[mod_slice1]]
-//       CHECK-TILE-2:   %[[r2:.*]] = linalg.tiled_loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]])
-//  CHECK-TILE-2-SAME:       to (%[[dim0]], %[[dim1]], %[[dim2]])
-//  CHECK-TILE-2-SAME:       step (%[[s0]], %[[s1]], %[[s2]])
-//  CHECK-TILE-2-SAME:       ins (%[[loop_in2:.*]] = %[[input]]: tensor<?x?x?xf32>)
-//  CHECK-TILE-2-SAME:       outs (%[[loop_out2:.*]] = %[[r1]]: tensor<?x?x?xf32>) {
-//       CHECK-TILE-2:     %[[min0_2:.*]] = affine.min
-//       CHECK-TILE-2:     %[[min1_2:.*]] = affine.min
-//       CHECK-TILE-2:     %[[apply2:.*]] = affine.apply
-//       CHECK-TILE-2:     %[[in_slice2:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
-//       CHECK-TILE-2:     %[[out_slice2:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
-//       CHECK-TILE-2:     %[[mod_slice2:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
-//       CHECK-TILE-2:     linalg.yield %[[mod_slice2]]
-//       CHECK-TILE-2:   return %[[r2]]
-
-// CHECK-TILE-012-LABEL: func @tiled_loop_3d_tensor
-//       CHECK-TILE-012:   linalg.tiled_loop {{.*}} {
-//       CHECK-TILE-012:     linalg.yield
-//       CHECK-TILE-012:   }
-//       CHECK-TILE-012:   linalg.tiled_loop {{.*}} {
-//       CHECK-TILE-012:     linalg.yield
-//       CHECK-TILE-012:   }
-//       CHECK-TILE-012:   linalg.tiled_loop {{.*}} {
-//       CHECK-TILE-012:     linalg.yield
-//       CHECK-TILE-012:   }
-//       CHECK-TILE-012:   linalg.tiled_loop {{.*}} {
-//       CHECK-TILE-012:     linalg.yield
-//       CHECK-TILE-012:   }
-//       CHECK-TILE-012:   linalg.tiled_loop {{.*}} {
-//       CHECK-TILE-012:     linalg.yield
-//       CHECK-TILE-012:   }
-//       CHECK-TILE-012:   linalg.tiled_loop {{.*}} {
-//       CHECK-TILE-012:     linalg.yield
-//       CHECK-TILE-012:   }
-//       CHECK-TILE-012:   linalg.tiled_loop {{.*}} {
-//       CHECK-TILE-012:     linalg.yield
-//       CHECK-TILE-012:   }
-//       CHECK-TILE-012:   linalg.tiled_loop {{.*}} {
-//       CHECK-TILE-012:     linalg.yield
-//       CHECK-TILE-012:   }
-//   CHECK-TILE-012-NOT: linalg.tiled_loop
-
-//      CHECK-TILE-012-SKIP-PARTIAL: func @tiled_loop_3d_tensor(
-// CHECK-TILE-012-SKIP-PARTIAL-SAME:     %[[input:.*]]: tensor<?x?x?xf32>
-//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[c0:.*]] = arith.constant 0 : index
-//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[c1:.*]] = arith.constant 1 : index
-//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[c2:.*]] = arith.constant 2 : index
-//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]]
-//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]]
-//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]]
-//      CHECK-TILE-012-SKIP-PARTIAL:   %[[p0:.*]] = affine.apply #{{.*}}()[%[[dim0]]
-//      CHECK-TILE-012-SKIP-PARTIAL:   %[[p1:.*]] = affine.apply #{{.*}}()[%[[dim1]]
-//      CHECK-TILE-012-SKIP-PARTIAL:   %[[p2:.*]] = affine.apply #{{.*}}()[%[[dim2]]
-//      CHECK-TILE-012-SKIP-PARTIAL:   linalg.tiled_loop {{.*}} = (%[[c0]], %[[c0]], %[[c0]]) to (%[[p0]], %[[p1]], %[[p2]])
-//      CHECK-TILE-012-SKIP-PARTIAL:   linalg.tiled_loop {{.*}} = (%[[c0]], %[[c0]], %[[p2]]) to (%[[p0]], %[[p1]], %[[dim2]])
-//      CHECK-TILE-012-SKIP-PARTIAL:   linalg.tiled_loop {{.*}} = (%[[c0]], %[[p1]], %[[c0]]) to (%[[p0]], %[[dim1]], %[[dim2]])
-//      CHECK-TILE-012-SKIP-PARTIAL:   linalg.tiled_loop {{.*}} = (%[[p0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim1]], %[[dim2]])
-func @tiled_loop_3d_tensor(%arg0: tensor<?x?x?xf32>, %s0: index, %s1: index,
-                           %s2: index) -> tensor<?x?x?xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c8 = arith.constant 8 : index
-  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
-  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
-  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
-  %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
-  %result = linalg.tiled_loop
-           (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
-           step (%s0, %s1, %s2) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
-           outs (%arg5 = %output: tensor<?x?x?xf32>) {
-    %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0]
-    %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1]
-    %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2]
-    %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
-    %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-    %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-    %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
-    linalg.yield %updated_slice : tensor<?x?x?xf32>
-  }
-  return %result : tensor<?x?x?xf32>
-}
-
-// -----
-
-// CHECK-TILE-2-LABEL: func @tiled_loop_3d_memref(
-//  CHECK-TILE-2-SAME:     %[[input:.*]]: memref<?x?x?xf32>, %[[output:.*]]: memref<?x?x?xf32>, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-TILE-2-DAG:   %[[c0:.*]] = arith.constant 0 : index
-//   CHECK-TILE-2-DAG:   %[[c1:.*]] = arith.constant 1 : index
-//   CHECK-TILE-2-DAG:   %[[c2:.*]] = arith.constant 2 : index
-//       CHECK-TILE-2:   %[[dim0:.*]] = memref.dim %[[input]], %[[c0]]
-//       CHECK-TILE-2:   %[[dim1:.*]] = memref.dim %[[input]], %[[c1]]
-//       CHECK-TILE-2:   %[[dim2:.*]] = memref.dim %[[input]], %[[c2]]
-//       CHECK-TILE-2:   %[[split_bound:.*]] = affine.apply
-//       CHECK-TILE-2:   linalg.tiled_loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]])
-//  CHECK-TILE-2-SAME:       to (%[[dim0]], %[[dim1]], %[[split_bound]])
-//  CHECK-TILE-2-SAME:       step (%[[s0]], %[[s1]], %[[s2]])
-//  CHECK-TILE-2-SAME:       ins (%[[loop_in1:.*]] = %[[input]]: memref<?x?x?xf32>)
-//  CHECK-TILE-2-SAME:       outs (%[[loop_out1:.*]] = %[[output]]: memref<?x?x?xf32>) {
-//       CHECK-TILE-2:     %[[min0_1:.*]] = affine.min
-//       CHECK-TILE-2:     %[[min1_1:.*]] = affine.min
-//       CHECK-TILE-2:     memref.subview %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
-//       CHECK-TILE-2:     linalg.yield
-//       CHECK-TILE-2:   linalg.tiled_loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]])
-//  CHECK-TILE-2-SAME:       to (%[[dim0]], %[[dim1]], %[[dim2]])
-//  CHECK-TILE-2-SAME:       step (%[[s0]], %[[s1]], %[[s2]])
-//  CHECK-TILE-2-SAME:       ins (%[[loop_in2:.*]] = %[[input]]: memref<?x?x?xf32>)
-//  CHECK-TILE-2-SAME:       outs (%[[loop_out2:.*]] = %[[output]]: memref<?x?x?xf32>) {
-//       CHECK-TILE-2:     %[[min0_2:.*]] = affine.min
-//       CHECK-TILE-2:     %[[min1_2:.*]] = affine.min
-//       CHECK-TILE-2:     %[[apply2:.*]] = affine.apply
-//       CHECK-TILE-2:     memref.subview %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
-//       CHECK-TILE-2:     linalg.yield
-//       CHECK-TILE-2:   return
-
-// CHECK-TILE-012-LABEL: func @tiled_loop_3d_memref
-
-!memref_subview_type = type memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>
-
-func @tiled_loop_3d_memref(%arg0: memref<?x?x?xf32>, %output: memref<?x?x?xf32>,
-                           %s0: index, %s1: index, %s2: index) {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c8 = arith.constant 8 : index
-  %dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
-  %dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
-  %dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
-  linalg.tiled_loop
-           (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
-           step (%s0, %s1, %s2) ins (%arg4 = %arg0: memref<?x?x?xf32>)
-           outs (%arg5 = %output : memref<?x?x?xf32>) {
-    %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0]
-    %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1]
-    %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2]
-    %in_slice = memref.subview %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: memref<?x?x?xf32> to !memref_subview_type
-    "computation"(%in_slice) : (!memref_subview_type) -> memref<?x?x?xf32>
-    linalg.yield
-  }
-  return
-}
-
-// -----
-
-// CHECK-TILE-2-LABEL: func @step_1_do_not_peel
-//       CHECK-TILE-2:   linalg.tiled_loop
-//   CHECK-TILE-2-NOT:   linalg.tiled_loop
-
-// CHECK-TILE-012-LABEL: func @step_1_do_not_peel
-
-func @step_1_do_not_peel(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c8 = arith.constant 8 : index
-  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
-  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
-  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
-  %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
-  %result = linalg.tiled_loop
-           (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
-           step (%c1, %c1, %c1) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
-           outs (%arg5 = %output: tensor<?x?x?xf32>) {
-    %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
-    %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-    %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-    %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
-    linalg.yield %updated_slice : tensor<?x?x?xf32>
-  }
-  return %result : tensor<?x?x?xf32>
-}
-
-// -----
-
-// CHECK-TILE-2-LABEL: func @divides_evenly_do_not_peel
-//       CHECK-TILE-2:   linalg.tiled_loop
-//   CHECK-TILE-2-NOT:   linalg.tiled_loop
-
-// CHECK-TILE-012-LABEL: func @divides_evenly_do_not_peel
-
-func @divides_evenly_do_not_peel(%arg0: tensor<?x?x?xf32>, %s: index)
-    -> tensor<?x?x?xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c8 = arith.constant 8 : index
-  %c64 = arith.constant 64 : index
-  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
-  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
-  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
-  %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
-  %result = linalg.tiled_loop
-           (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %c64)
-           step (%s, %s, %c8) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
-           outs (%arg5 = %output: tensor<?x?x?xf32>) {
-    %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
-    %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-    %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-    %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
-    linalg.yield %updated_slice : tensor<?x?x?xf32>
-  }
-  return %result : tensor<?x?x?xf32>
-}

diff  --git a/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir b/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir
deleted file mode 100644
index 08d81b4f96411..0000000000000
--- a/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir
+++ /dev/null
@@ -1,184 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-tiled-loops-to-scf --split-input-file | FileCheck %s
-
-
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-
-func @tiled_loop(%A: memref<192x192xf32>,
-                 %B: memref<192x192xf32>,
-                 %C: memref<192x192xf32>) {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c24 = arith.constant 24 : index
-  %c16 = arith.constant 16 : index
-  %c0 = arith.constant 0 : index
-  %c192 = arith.constant 192 : index
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
-      outs (%C_ = %C: memref<192x192xf32>) {
-    %0 = affine.min #map0(%i)
-    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
-      : memref<192x192xf32> to memref<?x192xf32, #map1>
-    %2 = affine.min #map2(%j)
-    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
-      : memref<192x192xf32> to memref<192x?xf32, #map1>
-    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
-      : memref<192x192xf32> to memref<?x?xf32, #map1>
-    linalg.fill(%cst, %4) : f32, memref<?x?xf32, #map1>
-    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
-                               memref<192x?xf32, #map1>)
-                  outs(%4 : memref<?x?xf32, #map1>)
-    linalg.yield
-  }
-  return
-}
-
-// CHECK-LABEL: @tiled_loop
-// CHECK-SAME:  %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
-// CHECK-SAME:  %[[C:.*]]: memref<192x192xf32>) {
-// CHECK:       %[[C24:.*]] = arith.constant 24 : index
-// CHECK:       %[[C16:.*]] = arith.constant 16 : index
-// CHECK:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[C192:.*]] = arith.constant 192 : index
-// CHECK:       scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
-// CHECK-SAME:      to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]]) {
-// CHECK:         %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
-// CHECK:         %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
-// CHECK:         %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
-// CHECK:         linalg.fill
-// CHECK:         linalg.matmul
-
-// -----
-
-func @tiled_loop_reduction(%A: memref<192x192xf32>,
-                           %B: memref<192x192xf32>,
-                           %C: memref<f32>) {
-   %c24 = arith.constant 24 : index
-   %c16 = arith.constant 16 : index
-   %c0 = arith.constant 0 : index
-   %c192 = arith.constant 192 : index
-   %cst = arith.constant 0.000000e+00 : f32
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
-      outs (%C_ = %C: memref<f32>)
-      iterators["reduction", "reduction"] {
-    linalg.fill(%cst, %A_) : f32, memref<192x192xf32>
-    linalg.yield
-  }
-  return
-}
-
-// CHECK-LABEL: @tiled_loop_reduction
-// CHECK:       %[[C24:.*]] = arith.constant 24 : index
-// CHECK:       %[[C16:.*]] = arith.constant 16 : index
-// CHECK:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[C192:.*]] = arith.constant 192 : index
-// CHECK:       scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
-// CHECK:         scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
-// CHECK:           linalg.fill
-
-// -----
-
-#strided_1d = affine_map<(d0)[s0] -> (d0 + s0)>
-#strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
-
-func @tiled_loop_row_reduction(%A: memref<10x8xf32>,
-                               %B: memref<8xf32>) {
-   %c0 = arith.constant 0 : index
-   %c2 = arith.constant 2 : index
-   %c4 = arith.constant 4 : index
-   %c8 = arith.constant 8 : index
-   %c10 = arith.constant 10 : index
-   %cst = arith.constant 0.000000e+00 : f32
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c10, %c8) step (%c2, %c4)
-      ins (%A_ = %A: memref<10x8xf32>)
-      outs (%B_ = %B: memref<8xf32>)
-      iterators["reduction", "parallel"] {
-    %A_sub = memref.subview %A_[%i, %j][2, 4][1, 1]
-      : memref<10x8xf32> to memref<2x4xf32, #strided_2d>
-    %B_sub = memref.subview %B_[%j][4][1]
-      : memref<8xf32> to memref<4xf32, #strided_1d>
-    linalg.generic {
-        indexing_maps = [affine_map<(i, j) -> (i, j)>,
-                         affine_map<(i, j) -> (j)>],
-        iterator_types = ["reduction", "parallel"]}
-        ins(%A_sub : memref<2x4xf32, #strided_2d>)
-        outs(%B_sub : memref<4xf32, #strided_1d>) {
-      ^bb(%a: f32, %b: f32) :
-        %0 = arith.addf %a, %b: f32
-        linalg.yield %0 : f32
-      }
-    linalg.yield
-  }
-  return
-}
-
-// CHECK-LABEL: @tiled_loop_row_reduction
-
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
-
-// CHECK:     scf.parallel (%[[J:.*]]) = (%[[C0]]) to (%[[C8]]) step (%[[C4]])
-// CHECK-NEXT:  scf.for %[[I:.*]] = %[[C0]] to %[[C10]] step %[[C2]]
-// CHECK-NEXT:    memref.subview %arg{{[0-9]+}}[%[[I]], %[[J]]] [2, 4] [1, 1]
-// CHECK-SAME:      : memref<10x8xf32> to memref<2x4xf32, #map{{[0-9]+}}>
-// CHECK-NEXT:    memref.subview %arg{{[0-9]+}}[%[[J]]] [4] [1]
-// CHECK-SAME:      : memref<8xf32> to memref<4xf32, #map{{[0-9]+}}>
-
-// -----
-
-#strided_1d = affine_map<(d0)[s0] -> (d0 + s0)>
-#strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
-
-func @tiled_loop_col_reduction(%A: memref<10x8xf32>,
-                               %B: memref<10xf32>) {
-   %c0 = arith.constant 0 : index
-   %c2 = arith.constant 2 : index
-   %c4 = arith.constant 4 : index
-   %c8 = arith.constant 8 : index
-   %c10 = arith.constant 10 : index
-   %cst = arith.constant 0.000000e+00 : f32
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c10, %c8) step (%c2, %c4)
-      ins (%A_ = %A: memref<10x8xf32>)
-      outs (%B_ = %B: memref<10xf32>)
-      iterators["parallel", "reduction"] {
-    %A_sub = memref.subview %A_[%i, %j][2, 4][1, 1]
-      : memref<10x8xf32> to memref<2x4xf32, #strided_2d>
-    %B_sub = memref.subview %B_[%i][2][1]
-      : memref<10xf32> to memref<2xf32, #strided_1d>
-    linalg.generic {
-        indexing_maps = [affine_map<(i, j) -> (i, j)>,
-                         affine_map<(i, j) -> (i)>],
-        iterator_types = ["parallel", "reduction"]}
-        ins(%A_sub : memref<2x4xf32, #strided_2d>)
-        outs(%B_sub : memref<2xf32, #strided_1d>) {
-      ^bb(%a: f32, %b: f32) :
-        %0 = arith.addf %a, %b: f32
-        linalg.yield %0 : f32
-      }
-    linalg.yield
-  }
-  return
-}
-
-// CHECK-LABEL: @tiled_loop_col_reduction
-
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
-
-// CHECK:     scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[C10]]) step (%[[C2]])
-// CHECK-NEXT:  scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]]
-// CHECK-NEXT:    memref.subview %arg{{[0-9]+}}[%[[I]], %[[J]]] [2, 4] [1, 1]
-// CHECK-SAME:      : memref<10x8xf32> to memref<2x4xf32, #map{{[0-9]+}}>
-// CHECK-NEXT:    memref.subview %arg{{[0-9]+}}[%[[I]]] [2] [1]
-// CHECK-SAME:      : memref<10xf32> to memref<2xf32, #map{{[0-9]+}}>

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index c74fb756b785f..1fe3db2e9e676 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,7 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
   TestLinalgCodegenStrategy.cpp
-  TestLinalgDistribution.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
deleted file mode 100644
index 342fed37ad600..0000000000000
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
+++ /dev/null
@@ -1,79 +0,0 @@
-//===- TestLinalgDistribution.cpp - Test Linalg hoisting functions --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements logic for testing Linalg hoisting functions.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-template <gpu::Dimension Dim>
-static linalg::ProcInfo getGpuBlockInfo(OpBuilder &b, Location loc) {
-  Type indexType = b.getIndexType();
-  ProcInfo procInfo = {b.create<gpu::BlockIdOp>(loc, indexType, Dim),
-                       b.create<gpu::GridDimOp>(loc, indexType, Dim)};
-  return procInfo;
-}
-
-static LinalgLoopDistributionOptions getDistributionOptions() {
-  LinalgLoopDistributionOptions opts;
-  opts.procInfoMap.insert(
-      std::make_pair("block_x", getGpuBlockInfo<gpu::Dimension::x>));
-  opts.procInfoMap.insert(
-      std::make_pair("block_y", getGpuBlockInfo<gpu::Dimension::y>));
-  return opts;
-}
-
-namespace {
-struct TestLinalgDistribution
-    : public PassWrapper<TestLinalgDistribution, OperationPass<FuncOp>> {
-  StringRef getArgument() const final { return "test-linalg-distribution"; }
-  StringRef getDescription() const final { return "Test Linalg distribution."; }
-  TestLinalgDistribution() = default;
-  TestLinalgDistribution(const TestLinalgDistribution &pass) = default;
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, gpu::GPUDialect>();
-  }
-
-  void runOnOperation() override;
-};
-} // namespace
-
-void TestLinalgDistribution::runOnOperation() {
-  auto funcOp = getOperation();
-  RewritePatternSet distributeTiledLoopsPatterns(&getContext());
-  populateLinalgDistributeTiledLoopPattern(
-      distributeTiledLoopsPatterns, getDistributionOptions(),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>{},
-          {StringAttr::get(funcOp.getContext(), "distributed")})
-          .addFilter([](Operation *op) {
-            return success(!op->getParentOfType<linalg::TiledLoopOp>());
-          }));
-  (void)applyPatternsAndFoldGreedily(funcOp,
-                                     std::move(distributeTiledLoopsPatterns));
-  // Ensure we drop the marker in the end.
-  funcOp.walk([](LinalgOp op) {
-    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
-  });
-}
-
-namespace mlir {
-namespace test {
-void registerTestLinalgDistribution() {
-  PassRegistration<TestLinalgDistribution>();
-}
-} // namespace test
-} // namespace mlir

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 8af5c46433631..c9c44bfc812ba 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -120,10 +120,6 @@ struct TestLinalgTransforms
       *this, "tile-sizes",
       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-  ListOption<unsigned> testTiledLoopPeeling{
-      *this, "test-tiled-loop-peeling",
-      llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
-      llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
   Option<bool> skipPartial{
       *this, "skip-partial",
       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
@@ -605,8 +601,7 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
           .Case("for", LinalgTilingLoopType::Loops)
           .Case("affine", LinalgTilingLoopType::AffineLoops)
-          .Case("parallel", LinalgTilingLoopType::ParallelLoops)
-          .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
+          .Case("parallel", LinalgTilingLoopType::ParallelLoops);
   auto linalgTilingOptions = linalg::LinalgTilingOptions()
                                  .setPeeledLoops(peeledLoops)
                                  .setLoopType(type);
@@ -626,76 +621,6 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
 static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
 
-namespace {
-/// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
-/// `idx`-th loop contains only "full" iterations and a second loop for the
-/// remaining partial iteration (if any).
-struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
-  TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
-      : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
-  }
-
-  LogicalResult matchAndRewrite(TiledLoopOp loopOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<int64_t> peeledLoops;
-    if (loopOp->hasAttr(kPeeledLoopsLabel)) {
-      auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
-      peeledLoops =
-          llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
-            return attr.cast<IntegerAttr>().getInt();
-          }));
-      // Check if the loop was already peeled.
-      if (llvm::find(peeledLoops, idx) != peeledLoops.end())
-        return failure();
-    }
-    if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
-      // No peeling of loop nests with a partial iteration.
-      return failure();
-
-    if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
-      return failure();
-
-    // Peel loop and canonicalize.
-    TiledLoopOp result;
-    if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
-                                                    result)))
-      return failure();
-
-    // Apply label, so that the same loop is not rewritten a second time.
-    peeledLoops.push_back(idx);
-    rewriter.updateRootInPlace(loopOp, [&]() {
-      loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
-    });
-    result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
-    result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
-
-    return success();
-  }
-
-  /// Index of loop to peel.
-  int64_t idx;
-
-  /// If set to true, do not peel TiledLoopOps with a partial iteration.
-  bool skipPartial;
-};
-} // namespace
-
-static void applyTiledLoopPeelingPattern(FuncOp funcOp,
-                                         ArrayRef<unsigned> loops,
-                                         bool skipPartial) {
-  MLIRContext *ctx = funcOp.getContext();
-  RewritePatternSet patterns(ctx);
-  for (unsigned idx : loops)
-    patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-
-  // Drop the markers.
-  funcOp.walk([](TiledLoopOp op) {
-    op->removeAttr(kPeeledLoopsLabel);
-    op->removeAttr(kPartialIterationLabel);
-  });
-}
-
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   auto lambda = [&](void *) {
@@ -739,9 +664,6 @@ void TestLinalgTransforms::runOnOperation() {
     return applyGeneralizePadTensorPatterns(getOperation());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
-  if (testTiledLoopPeeling.hasValue())
-    return applyTiledLoopPeelingPattern(getOperation(), testTiledLoopPeeling,
-                                        skipPartial);
   if (testTilePattern)
     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
                             /*scalarizeDynamicDims=*/false);

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 704aab507883e..2bb690bede084 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -81,7 +81,6 @@ void registerTestGenericIRVisitorsPass();
 void registerTestGenericIRVisitorsInterruptPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
-void registerTestLinalgDistribution();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
@@ -171,7 +170,6 @@ void registerTestPasses() {
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLinalgCodegenStrategy();
-  mlir::test::registerTestLinalgDistribution();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgFusionTransforms();
   mlir::test::registerTestLinalgTensorFusionTransforms();


        


More information about the Mlir-commits mailing list