[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)

Hsiangkai Wang llvmlistbot at llvm.org
Wed Aug 14 05:45:06 PDT 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96184

>From bad2ae08252a7d95e4655cf4fe080004a440ecf9 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:44:27 +0100
Subject: [PATCH 1/7] [mlir][linalg] Implement TilingInterface for winograd
 operations

To support conv2d inputs of arbitrary size, implement TilingInterface for
the winograd operations. Before converting winograd operations into nested
loops of matrix multiplications, first tile the conv2d input into the
supported sizes.

Add a transform operation, structured.decompose_winograd_op, to decompose
winograd operations. Before applying this transform op, use tile_using_for
to tile the input data into the supported sizes. The test case shows how to
tile and decompose winograd operations.
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  22 +-
 .../Linalg/TransformOps/LinalgTransformOps.td |  37 +++
 .../Dialect/Linalg/Transforms/Transforms.h    |  45 +++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 216 +++++++++++++++
 .../TransformOps/LinalgTransformOps.cpp       |  41 +++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  18 ++
 .../transform-tile-and-winograd-rewrite.mlir  | 260 ++++++++++++++++++
 7 files changed, 633 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a9007c8db3078e..bc2df7082ef8e3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,8 +154,8 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradFilterTransformOp :
-    Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+    [AllElementTypesMatch<["filter", "output"]>]> {
   let summary = "Winograd filter transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -193,8 +193,13 @@ def Linalg_WinogradFilterTransformOp :
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradInputTransformOp :
-    Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+    [AllElementTypesMatch<["input", "output"]>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd input transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -232,8 +237,13 @@ def Linalg_WinogradInputTransformOp :
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradOutputTransformOp :
-    Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+    [AllElementTypesMatch<["value", "output"]>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd output transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ecc86999006db6..106f0d79d9792d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   }];
 }
 
+def DecomposeWinogradOp : Op<Transform_Dialect,
+    "structured.decompose_winograd_op",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Decompose winograd operations. It will convert filter, input and output
+    transform operations into a combination of scf, tensor, and linalg
+    equivalent operations. Before applying this transform operation, users
+    need to tile the winograd transform operations into supported sizes
+    (for example with `transform.structured.tile_using_for`).
+
+    #### Return modes:
+
+    This operation produces a silenceable failure if `target` is not a
+    winograd filter, input or output transform operation, or if decomposing
+    it fails. Otherwise, the operation succeeds and returns a handle to the
+    sequence that replaces the original operations.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 477ef7bfafb181..89f907cfc7ade7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1316,6 +1316,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                       int64_t r);
 
+/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
+/// FHWC. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
+/// C. After the rewriting, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into output<h x w x c x f>
+///
+/// Returns the replacement operation, or failure if `op` could not be
+/// decomposed.
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform. The data layout of the input is
+/// NHWC. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
+/// C. After the rewriting, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into output<h x w x n x c>
+///
+/// NOTE(review): the loop nest above covers a single H x W tile; the op is
+/// expected to have been tiled into supported sizes beforehand.
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform. The data layout of the `value`
+/// operand (the batched matmul result) is HWNF. The transformation matrix is
+/// 2-dimensional. We need to extract H x W from HWNF first. We need to
+/// generate 2 levels of loops to iterate on N and F. After the
+/// transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %f = lo_f to hi_f step 1
+///     %extracted = extract value<h x w> from value<h x w x n x f>
+///     %ret = linalg.matmul AT, %extracted
+///     %ret = linalg.matmul %ret, A
+///     %inserted = insert %ret into output<n x h x w x f>
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a101552e419bc8..47800dfc1d1a2a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2885,6 +2885,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
 // WinogradInputTransformOp
 //===----------------------------------------------------------------------===//
 
+/// Materializes `opFoldResult` as an SSA value usable as an `affine.apply`
+/// operand. Integer attributes are rebuilt as `index`-typed `arith.constant`
+/// ops, because `affine.apply` only accepts `index` operands and the attribute
+/// may carry another integer type; values are returned unchanged.
+/// File-local helper, hence `static`.
+static Value getValueFromOpFoldResult(OpFoldResult opFoldResult,
+                                      OpBuilder &builder, Location loc) {
+  if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
+    int64_t constant = cast<IntegerAttr>(attr).getInt();
+    return builder.create<arith::ConstantIndexOp>(loc, constant);
+  }
+  return opFoldResult.get<Value>();
+}
+
 LogicalResult WinogradInputTransformOp::verify() {
   auto inputType = cast<ShapedType>(getInput().getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2922,6 +2931,113 @@ LogicalResult WinogradInputTransformOp::verify() {
   return success();
 }
 
+/// The iteration space of the input transform is the 6-D shape of its output
+/// operand: one unit-stride range per output dimension, starting at 0.
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IntegerAttr zero = builder.getIndexAttr(0);
+  IntegerAttr one = builder.getIndexAttr(1);
+  Value output = getOutput();
+  SmallVector<Range> domain;
+  domain.reserve(6);
+  for (int64_t dim = 0; dim < 6; ++dim)
+    domain.push_back(Range{zero, getDimValue(builder, loc, output, dim), one});
+  return domain;
+}
+
+/// All six dimensions of the input transform's iteration space are parallel.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
+}
+
+/// Maps an iteration-space tile to the slice of the (single) result it
+/// produces. The iteration space is the 6-D output shape, so offsets and sizes
+/// carry over one-to-one, except that the two leading (alpha) dimensions are
+/// always produced starting at offset 0.
+/// NOTE(review): the alpha dimensions are expected to be left untiled.
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+
+  // Dimensions 2/3 index the output tiles and 4/5 the batch and channel.
+  // Their offsets are forwarded so that tiling any of them selects the
+  // matching slice of the result.
+  resultOffsets.append(
+      {zeroAttr, zeroAttr, offsets[2], offsets[3], offsets[4], offsets[5]});
+  resultSizes.append(
+      {sizes[0], sizes[1], sizes[2], sizes[3], sizes[4], sizes[5]});
+  return success();
+}
+
+/// Builds the tiled input transform for the iteration-space tile described by
+/// `offsets`/`sizes`. It slices the NHWC input window that feeds the requested
+/// output tiles, slices the matching part of the output, and clones this op
+/// onto those slices.
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  Location loc = getLoc();
+  MLIRContext *context = builder.getContext();
+
+  auto inputShape = cast<ShapedType>(getInput().getType()).getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+
+  // Output tile `t` reads the input window starting at `t * m`.
+  AffineMap offsetMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+  Value offsetH = builder.create<affine::AffineApplyOp>(
+      loc, offsetMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+  Value offsetW = builder.create<affine::AffineApplyOp>(
+      loc, offsetMap, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+  // `s` consecutive output tiles read (s - 1) * m + alpha = s * m + (r - 1)
+  // input elements, with alpha = m + r - 1 (exactly alpha when s == 1). An
+  // input dimension of extent 1 is not transformed and stays 1 wide. The
+  // expression is folded so that constant tile sizes keep static shapes.
+  AffineExpr windowExpr = builder.getAffineDimExpr(0) * m + (r - 1);
+  OpFoldResult sizeH =
+      inputH != 1 ? affine::makeComposedFoldedAffineApply(builder, loc,
+                                                          windowExpr, {sizes[2]})
+                  : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      inputW != 1 ? affine::makeComposedFoldedAffineApply(builder, loc,
+                                                          windowExpr, {sizes[3]})
+                  : OpFoldResult(oneAttr);
+
+  // Input slice: batch (iteration dim 4) and channel (dim 5) follow the tile.
+  SmallVector<OpFoldResult> inputOffsets = {offsets[4], offsetH, offsetW,
+                                            offsets[5]};
+  SmallVector<OpFoldResult> inputSizes = {sizes[4], sizeH, sizeW, sizes[5]};
+  SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getInput(), inputOffsets, inputSizes, inputStrides));
+
+  // Output slice: exactly the result tile this iteration tile produces.
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   resultOffsets, resultSizes)))
+    return failure();
+  SmallVector<OpFoldResult> outputStrides(6, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+
+  // The tiled op returns a tensor typed like its sliced destination.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // WinogradOutputTransformOp
 //===----------------------------------------------------------------------===//
@@ -2964,6 +3080,106 @@ LogicalResult WinogradOutputTransformOp::verify() {
   return success();
 }
 
+/// The iteration space of the output transform is the 6-D shape of its `value`
+/// operand: one unit-stride range per dimension, starting at 0.
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IntegerAttr zero = builder.getIndexAttr(0);
+  IntegerAttr one = builder.getIndexAttr(1);
+  Value value = getValue();
+  SmallVector<Range> domain;
+  domain.reserve(6);
+  for (int64_t dim = 0; dim < 6; ++dim)
+    domain.push_back(Range{zero, getDimValue(builder, loc, value, dim), one});
+  return domain;
+}
+
+/// All six dimensions of the output transform's iteration space are parallel.
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
+}
+
+/// Maps an iteration-space tile (over the 6-D `value` shape) to the slice of
+/// the NHWF output it produces. Output tile `t` covers rows/columns
+/// `[t * m, (t + 1) * m)`, while the batch (dim 4) and channel (dim 5)
+/// offsets and sizes carry over unchanged.
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  auto outputShape = cast<ShapedType>(getOutput().getType()).getShape();
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t m = getM();
+  Location loc = getLoc();
+  MLIRContext *context = builder.getContext();
+
+  AffineExpr scaleExpr = builder.getAffineDimExpr(0) * m;
+  AffineMap scaleMap = AffineMap::get(1, 0, {scaleExpr}, context);
+  Value offsetH = builder.create<affine::AffineApplyOp>(
+      loc, scaleMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+  Value offsetW = builder.create<affine::AffineApplyOp>(
+      loc, scaleMap, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+  // `s` consecutive tiles produce `s * m` output elements (m when s == 1). An
+  // output dimension of extent 1 is not transformed and stays 1 wide. Folding
+  // keeps constant tile sizes static.
+  OpFoldResult sizeH =
+      outputH != 1 ? affine::makeComposedFoldedAffineApply(builder, loc,
+                                                           scaleExpr, {sizes[2]})
+                   : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      outputW != 1 ? affine::makeComposedFoldedAffineApply(builder, loc,
+                                                           scaleExpr, {sizes[3]})
+                   : OpFoldResult(oneAttr);
+
+  // Forward the batch and channel offsets so that tiling those dimensions
+  // selects the matching output slice.
+  resultOffsets.append({offsets[4], offsetH, offsetW, offsets[5]});
+  resultSizes.append({sizes[4], sizeH, sizeW, sizes[5]});
+  return success();
+}
+
+/// Builds the tiled output transform for the iteration-space tile described by
+/// `offsets`/`sizes`. It slices the requested tiles out of `value`, slices the
+/// output region they produce, and clones this op onto those slices.
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Location loc = getLoc();
+
+  // `value` slice: full alpha x alpha blocks; the tile, batch and channel
+  // dimensions come from the iteration tile.
+  SmallVector<OpFoldResult> valueOffsets = {zeroAttr,   zeroAttr,
+                                            offsets[2], offsets[3],
+                                            offsets[4], offsets[5]};
+  SmallVector<OpFoldResult> valueSizes = {sizes[0], sizes[1], sizes[2],
+                                          sizes[3], sizes[4], sizes[5]};
+  SmallVector<OpFoldResult> valueStrides(6, oneAttr);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getValue(), valueOffsets, valueSizes, valueStrides));
+
+  // Output slice: exactly the result tile this iteration tile produces.
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   resultOffsets, resultSizes)))
+    return failure();
+  SmallVector<OpFoldResult> outputStrides(4, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+
+  // The tiled op returns a tensor typed like its sliced destination.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9baf358a95503c..325dc04a828f09 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3853,6 +3853,47 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Decomposes a single winograd filter/input/output transform payload op.
+/// Emits a silenceable error when the op kind is not handled or when its
+/// decomposition fails; otherwise records the replacement op in `results`.
+DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Silenceable error anchored on this transform op, with a note pointing at
+  // the payload op.
+  auto reportFailure = [&](StringRef message) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError() << message;
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  };
+
+  rewriter.setInsertionPoint(target);
+  FailureOr<Operation *> maybeTransformed = failure();
+  if (auto filterOp = dyn_cast<linalg::WinogradFilterTransformOp>(target)) {
+    maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, filterOp);
+  } else if (auto inputOp =
+                 dyn_cast<linalg::WinogradInputTransformOp>(target)) {
+    maybeTransformed = decomposeWinogradInputTransformOp(rewriter, inputOp);
+  } else if (auto outputOp =
+                 dyn_cast<linalg::WinogradOutputTransformOp>(target)) {
+    maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, outputOp);
+  } else {
+    return reportFailure(
+        "this operation is not supported to decompose into other operations");
+  }
+
+  if (failed(maybeTransformed))
+    return reportFailure("decompose Winograd operations failed");
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index c6c770e2781ff0..c01a5f4b055f1a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -1169,6 +1169,24 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
   return winogradConv2DHelper(rewriter, op, m, r);
 }
 
+/// Public entry point for decomposing a linalg.winograd_filter_transform op;
+/// forwards to decomposeWinogradFilterTransformHelper.
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op) {
+  return decomposeWinogradFilterTransformHelper(rewriter, op);
+}
+
+/// Public entry point for decomposing a linalg.winograd_input_transform op;
+/// forwards to decomposeWinogradInputTransformHelper.
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op) {
+  return decomposeWinogradInputTransformHelper(rewriter, op);
+}
+
+/// Public entry point for decomposing a linalg.winograd_output_transform op;
+/// forwards to decomposeWinogradOutputTransformHelper.
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op) {
+  return decomposeWinogradOutputTransformHelper(rewriter, op);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 00000000000000..29833324bb21bc
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,260 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+// Aligned case: the 2x10x10x5 input with m(4) r(3) splits into exactly 2x2
+// output tiles (2x8x8x2 result), so no padding is needed.
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %2 = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+  %4 = tensor.empty() : tensor<36x8x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %6 : tensor<2x8x8x2xf32>
+}
+
+// Tile the input and output transforms over their two tile dimensions
+// (iteration dims 2 and 3) with tile size 1, then decompose the filter
+// transform and the tiled input/output transforms.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// Expected IR: the filter transform becomes an F x C loop nest of matmuls;
+// the input/output transforms become per-tile loops wrapping N x C (N x F)
+// matmul loop nests.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK:        scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:     scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK:            linalg.matmul
+// CHECK:            linalg.matmul
+// CHECK:        %[[S5:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK:        %[[S6:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK:        scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:     scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK:            %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:            %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:            tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK:            %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:            %[[S15:.*]] = scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK-NEXT:         scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG8]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:                linalg.matmul
+// CHECK:                linalg.matmul
+// CHECK:          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
+// CHECK:       %[[S9:.*]] = linalg.batch_matmul
+// CHECK:       %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK:   %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK:   scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:     scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK:       tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK:       %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:       %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:         scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:           linalg.matmul
+// CHECK:           linalg.matmul
+// CHECK:       %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:       %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+
+// -----
+
+// Unaligned case: the 11x11 input is padded to 14x14 and the 9x9 output to
+// 12x12 so that both split into 3x3 tiles for m(4) r(3); the result is sliced
+// back to 2x9x9x2 at the end.
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+  ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+    tensor.yield %cst : f32
+  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+  %2 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+  %4 = tensor.empty() : tensor<36x18x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+  %padded_1 = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+  ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+    tensor.yield %cst : f32
+  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+  return %extracted_slice : tensor<2x9x9x2xf32>
+}
+
+// Same schedule as the aligned case: tile the input/output transforms over
+// iteration dims 2 and 3 with tile size 1, then decompose all three
+// winograd transforms.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK:   scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK:       %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK:       tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK:       linalg.matmul
+// CHECK:       %[[S13:.*]] = linalg.matmul
+// CHECK:       %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK:       %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
+// CHECK:   %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK:   ^bb0({{.*}}):
+// CHECK:     tensor.yield %[[CST_6]] : f32
+// CHECK:   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK:   %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK:   scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK:       %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:       tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
+// CHECK:       %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:       %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:         scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:           linalg.matmul
+// CHECK:           %[[S17:.*]] = linalg.matmul
+// CHECK:           %[[S18:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+// CHECK:           %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
+// CHECK:           %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
+// CHECK:           scf.yield %[[INSERTED_SLICE_14]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
+// CHECK:   %[[S6:.*]] = linalg.batch_matmul
+// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK:   %[[PADDED_8:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK:   ^bb0({{.*}}):
+// CHECK:     tensor.yield %[[CST_6]] : f32
+// CHECK:   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK:   %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK:   %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK:       tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK:       %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:       %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
+// CHECK:       %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:         scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:           linalg.matmul
+// CHECK:           linalg.matmul
+// CHECK:           %[[S22:.*]] = linalg.mul
+// CHECK:           %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK:           %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK:           %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK:           scf.yield %[[INSERTED_SLICE_14]] : tensor<2x4x4x2xf32>
+// CHECK:       %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
+// CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK: }
+
+// -----
+
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = tensor.empty() : tensor<6x1x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+  %4 = tensor.empty() : tensor<6x2x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %6 : tensor<2x4x1x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[C5:.*]] = arith.constant 5 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
+// CHECK:      %[[S9:.*]] = linalg.matmul
+// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
+// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
+// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
+// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
+// CHECK:      %[[S9:.*]] = linalg.matmul
+// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> into tensor<6x1x1x1x2x5xf32>
+// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x1x1x2x5xf32>
+// CHECK:  %[[S5:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1x1x1x1x1xf32>
+// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> to tensor<6x1xf32>
+// CHECK:      linalg.matmul
+// CHECK:      %[[S12:.*]] = linalg.mul
+// CHECK:      %[[S13:.*]] = tensor.empty() : tensor<1x4x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[S13]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
+// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>

>From 0c121bbef5f970a6ebc632c4d791ba535f2c29f7 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 17 Jul 2024 18:26:14 +0100
Subject: [PATCH 2/7] Address ftynse's comments

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 122 +++----
 .../transform-tile-and-winograd-rewrite.mlir  | 332 ++++++++++--------
 2 files changed, 226 insertions(+), 228 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 47800dfc1d1a2a..4905e3f8283f89 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2885,15 +2885,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
 // WinogradInputTransformOp
 //===----------------------------------------------------------------------===//
 
-Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
-                               Location loc) {
-  if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
-    auto intAttr = cast<IntegerAttr>(attr);
-    return builder.create<arith::ConstantOp>(loc, intAttr);
-  }
-  return opFoldResult.get<Value>();
-}
-
 LogicalResult WinogradInputTransformOp::verify() {
   auto inputType = cast<ShapedType>(getInput().getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2934,9 +2925,9 @@ LogicalResult WinogradInputTransformOp::verify() {
 SmallVector<Range>
 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
-  auto indexType = builder.getIndexType();
-  auto zeroAttr = builder.getIntegerAttr(indexType, 0);
-  auto oneAttr = builder.getIntegerAttr(indexType, 1);
+  IndexType indexType = builder.getIndexType();
+  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
   Value output = getOutput();
   SmallVector<Range> loopBounds(6);
   for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2958,21 +2949,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  auto zeroAttr = builder.getI64IntegerAttr(0);
-  auto oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
 
-  resultOffsets.push_back(zeroAttr);
-  resultOffsets.push_back(zeroAttr);
-  resultOffsets.push_back(offsets[2]);
-  resultOffsets.push_back(offsets[3]);
-  resultOffsets.push_back(zeroAttr);
-  resultOffsets.push_back(zeroAttr);
-  resultSizes.push_back(sizes[0]);
-  resultSizes.push_back(sizes[1]);
-  resultSizes.push_back(oneAttr);
-  resultSizes.push_back(oneAttr);
-  resultSizes.push_back(sizes[4]);
-  resultSizes.push_back(sizes[5]);
+  resultOffsets.append(
+      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+  resultSizes.append(
+      {sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
 
   return success();
 }
@@ -2981,11 +2964,11 @@ FailureOr<TilingResult>
 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
-  auto oneAttr = builder.getI64IntegerAttr(1);
-  auto zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   Value input = getInput();
   auto inputType = cast<ShapedType>(input.getType());
-  auto inputShape = inputType.getShape();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[1];
   int64_t inputW = inputShape[2];
   int64_t m = getM();
@@ -2993,29 +2976,25 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   int64_t alpha = m + r - 1;
   int64_t alphaH = inputH != 1 ? alpha : 1;
   int64_t alphaW = inputW != 1 ? alpha : 1;
-  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
-  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
   Location loc = getLoc();
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
-  auto context = builder.getContext();
+  MLIRContext *context = builder.getContext();
   auto affineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
   Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
   Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
-
-  sliceOffsets.push_back(zeroAttr);
-  sliceOffsets.push_back(mappedOffset1);
-  sliceOffsets.push_back(mappedOffset2);
-  sliceOffsets.push_back(zeroAttr);
-  sliceSizes.push_back(sizes[4]);
-  sliceSizes.push_back(alphaHAttr);
-  sliceSizes.push_back(alphaWAttr);
-  sliceSizes.push_back(sizes[5]);
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+  sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+  sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
   SmallVector<OpFoldResult> inputStrides(4, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
@@ -3030,7 +3009,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
 
-  SmallVector<Type, 4> resultTypes;
+  SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
@@ -3083,9 +3062,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
 SmallVector<Range>
 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
-  auto indexType = builder.getIndexType();
-  auto zeroAttr = builder.getIntegerAttr(indexType, 0);
-  auto oneAttr = builder.getIntegerAttr(indexType, 1);
+  IndexType indexType = builder.getIndexType();
+  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
   Value value = getValue();
   SmallVector<Range> loopBounds(6);
   for (unsigned dim = 0; dim < 6; ++dim) {
@@ -3107,57 +3086,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  auto zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   Value output = getOutput();
   auto outputType = cast<ShapedType>(output.getType());
-  auto outputShape = outputType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
   int64_t outputH = outputShape[1];
   int64_t outputW = outputShape[2];
   int64_t m = getM();
-  auto heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
-  auto widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
+  IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
+  IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
 
   Location loc = getLoc();
-  auto context = builder.getContext();
+  MLIRContext *context = builder.getContext();
   auto affineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
   Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
   Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
-
-  resultOffsets.push_back(zeroAttr);
-  resultOffsets.push_back(mappedOffset1);
-  resultOffsets.push_back(mappedOffset2);
-  resultOffsets.push_back(zeroAttr);
-  resultSizes.push_back(sizes[4]);
-  resultSizes.push_back(heightM);
-  resultSizes.push_back(widthM);
-  resultSizes.push_back(sizes[5]);
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+  resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+  resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
   return success();
 }
 
 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes) {
-  auto oneAttr = builder.getI64IntegerAttr(1);
-  auto zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   Location loc = getLoc();
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
-  sliceOffsets.push_back(zeroAttr);
-  sliceOffsets.push_back(zeroAttr);
-  sliceOffsets.push_back(offsets[2]);
-  sliceOffsets.push_back(offsets[3]);
-  sliceOffsets.push_back(zeroAttr);
-  sliceOffsets.push_back(zeroAttr);
-  sliceSizes.push_back(sizes[0]);
-  sliceSizes.push_back(sizes[1]);
-  sliceSizes.push_back(oneAttr);
-  sliceSizes.push_back(oneAttr);
-  sliceSizes.push_back(sizes[4]);
-  sliceSizes.push_back(sizes[5]);
+  sliceOffsets.append(
+      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+  sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
   SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
@@ -3172,7 +3138,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getOutput(), sliceOffsets, sliceSizes, strides));
 
-  SmallVector<Type, 4> resultTypes;
+  SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 29833324bb21bc..6bb3fb1423edc6 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -31,52 +31,77 @@ module attributes {transform.with_named_sequence} {
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK:        scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT:     scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK:            linalg.matmul
-// CHECK:            linalg.matmul
-// CHECK:        %[[S5:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:        %[[S6:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:        scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<6x6x2x2x2x5xf32>) {
-// CHECK-NEXT:     scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
-// CHECK:            %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:            %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:            tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
-// CHECK:            %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK:            %[[S15:.*]] = scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK-NEXt:         scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG8]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK:                linalg.matmul
-// CHECK:                linalg.matmul
-// CHECK:          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
-// CHECK:       %[[S9:.*]] = linalg.batch_matmul
-// CHECK:       %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK:   %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK:   scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
-// CHECK-NEXT:     scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
-// CHECK:       tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK:       %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:       %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
-// CHECK:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:         scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK:           linalg.matmul
-// CHECK:           linalg.matmul
-// CHECK:       %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:       %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[C5:.*]] = arith.constant 5 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[S0:.*]] = tensor.empty()
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK:      %[[S11:.*]] = linalg.matmul
+// CHECK:      %[[S13:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK:          %[[S15:.*]] = linalg.matmul
+// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
+// CHECK:        scf.yield %[[S13]]
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[S6:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
+// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:          %[[S19:.*]] = linalg.matmul
+// CHECK:          %[[S20:.*]] = tensor.empty()
+// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:            linalg.yield %[[IN]] : f32
+// CHECK:          } -> tensor<4x4xf32>
+// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
+// CHECK:        scf.yield %[[S15]]
+// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
 
 // -----
 
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<6x6x5x2xf32>
   %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
@@ -91,7 +116,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
   %4 = tensor.empty() : tensor<36x18x2xf32>
   %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
   %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
-  %padded_1 = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+  %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
   ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
     tensor.yield %cst : f32
   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
@@ -117,80 +142,82 @@ module attributes {transform.with_named_sequence} {
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK:   scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK:       %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
-// CHECK:       tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
-// CHECK:       linalg.matmul
-// CHECK:       %[[S13:.*]] = linalg.matmul
-// CHECK:       %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
-// CHECK:       %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
-// CHECK:   %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
-// CHECK:   ^bb0({{.*}}):
-// CHECK:     tensor.yield %[[CST_6]] : f32
-// CHECK:   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
-// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:   %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:   scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK:       %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:       %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:       tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
-// CHECK:       %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK:       %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK:         scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK:           linalg.matmul
-// CHECK:           %[[S17:.*]] = linalg.matmul
-// CHECK:           %[[S18:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK:           %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
-// CHECK:           %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
-// CHECK:           scf.yield %[[INSERTED_SLICE_14]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
-// CHECK:   %[[S6:.*]] = linalg.batch_matmul
-// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
-// CHECK:   %[[PADDED_8:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
-// CHECK:   ^bb0({{.*}}):
-// CHECK:     tensor.yield %[[CST_6]] : f32
-// CHECK:   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK:   %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK:   %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK:       tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK:       %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:       %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:       %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
-// CHECK:       %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK:         scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK:           linalg.matmul
-// CHECK:           linalg.matmul
-// CHECK:           %[[S22:.*]] = linalg.mul
-// CHECK:           %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
-// CHECK:           %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
-// CHECK:           %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
-// CHECK:           scf.yield %[[INSERTED_SLICE_14]] : tensor<2x4x4x2xf32>
-// CHECK:       %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:       %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
-// CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK: }
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK:  %[[C3:.*]] = arith.constant 3 : index
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[C5:.*]] = arith.constant 5 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[S0:.*]] = tensor.empty()
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK:      %[[S11:.*]] = linalg.matmul
+// CHECK:      %[[S13:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK:          %[[S15:.*]] = linalg.matmul
+// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:        scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[S6:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// CHECK:  %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:          %[[S19:.*]] = linalg.matmul
+// CHECK:          %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:            linalg.yield %[[IN]] : f32
+// CHECK:          } -> tensor<4x4xf32>
+// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_12]]
+// CHECK:        scf.yield %[[S15]] : tensor<2x4x4x2xf32>
+// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK:  return %[[EXTRACTED_SLICE]]
 
 // -----
 
-func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
   %0 = tensor.empty() : tensor<6x1x5x2xf32>
   %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
   %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
@@ -200,7 +227,7 @@ func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
   %4 = tensor.empty() : tensor<6x2x2xf32>
   %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
   %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
   return %6 : tensor<2x4x1x2xf32>
 }
 
@@ -220,41 +247,46 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d_mx1_rx1
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK:  %[[C1:.*]] = arith.constant 1 : index
-// CHECK:  %[[C5:.*]] = arith.constant 5 : index
-// CHECK:  %[[C2:.*]] = arith.constant 2 : index
-// CHECK:  %[[C0:.*]] = arith.constant 0 : index
-// CHECK:  %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
-// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
-// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
-// CHECK:      %[[S9:.*]] = linalg.matmul
-// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
-// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
-// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
-// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x1x1x2x5xf32>) {
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
-// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
-// CHECK:      %[[S9:.*]] = linalg.matmul
-// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> into tensor<6x1x1x1x2x5xf32>
-// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x1x1x2x5xf32>
-// CHECK:  %[[S5:.*]] = linalg.batch_matmul
-// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x4x1x2xf32>) {
-// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x4x1x2xf32>) {
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1x1x1x1x1xf32>
-// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> to tensor<6x1xf32>
-// CHECK:      linalg.matmul
-// CHECK:      %[[S12:.*]] = linalg.mul
-// CHECK:      %[[S13:.*]] = tensor.empty() : tensor<1x4x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[S13]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
-// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S9:.*]] = linalg.matmul
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S9:.*]] = linalg.matmul
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:   %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:   %[[S5:.*]] = linalg.batch_matmul
+// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
+// CHECK:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       %[[S9:.*]] = linalg.matmul
+// CHECK:       %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK:       %[[S11:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) {
+// CHECK:       ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:         linalg.yield %[[IN]] : f32
+// CHECK:       } -> tensor<4x1xf32>
+// CHECK:       %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   return %[[S6]]

>From b4e5c7c97b770391c81839261d47956925e23daf Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 18 Jul 2024 06:52:17 +0100
Subject: [PATCH 3/7] Address more comments

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4905e3f8283f89..da55360c82eb0b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2999,15 +2999,14 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
 
-  sliceOffsets.clear();
-  sliceSizes.clear();
-  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
-                                   sliceSizes)))
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+                                   resultSizes)))
     return failure();
 
   SmallVector<OpFoldResult> outputStrides(6, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
@@ -3128,15 +3127,14 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
 
-  sliceOffsets.clear();
-  sliceSizes.clear();
-  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
-                                   sliceSizes)))
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+                                   resultSizes)))
     return failure();
 
   SmallVector<OpFoldResult> strides(4, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+      loc, getOutput(), resultOffsets, resultSizes, strides));
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());

>From e972232669e6de1638a86b62808cff7290ff3e70 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 18 Jul 2024 07:24:42 +0100
Subject: [PATCH 4/7] Update comments

---
 .../Dialect/Linalg/Transforms/Transforms.h    | 52 ++++++++++++-------
 1 file changed, 32 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 89f907cfc7ade7..861e14d22d9625 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1318,8 +1318,8 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
 
 /// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
 /// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
-/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
-/// C. After the rewriting, we get
+/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
+/// the rewriting, we get
 ///
 /// scf.for %f = lo_f to hi_f step 1
 ///   scf.for %c = lo_c to hi_c step 1
@@ -1333,30 +1333,42 @@ decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
 
 /// Rewrite linalg.winograd_input_transform. The data layout of the input is
 /// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
-/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
-/// C. After the rewriting, we get
-///
-/// scf.for %n = lo_n to hi_n step 1
-///   scf.for %c = lo_c to hi_c step 1
-///     %extracted = extract input<h x w> from input<n x h x w x c>
-///     %ret = linalg.matmul BT, %extracted
-///     %ret = linalg.matmul %ret, B
-///     %inserted = insert %ret into input<h x w x n x c>
+/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
+/// and tileW. After the rewriting, we get
+///
+/// scf.for %h = 0 to tileH step 1
+///   scf.for %w = 0 to tileW step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %c = 0 to C step 1
+///         %extracted = extract %extracted<alphaH x alphaW> from
+///                              %input<N x H x W x C>
+///                              at [%n, (%h x m), (%w x m), %c]
+///         %ret = linalg.matmul BT, %extracted
+///         %ret = linalg.matmul %ret, B
+///         %inserted = insert %ret<alphaH x alphaW> into
+///                            %output<alphaH x alphaW x tileH x tileW x N x C>
+///                            at [0, 0, %h, %w, %n, %c]
 FailureOr<Operation *>
 decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradInputTransformOp op);
 
 /// Rewrite linalg.winograd_output_transform. The data layout of the output is
 /// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
-/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
-/// F. After the transformation, we get
-///
-/// scf.for %n = lo_n to hi_n step 1
-///   scf.for %f = lo_f to hi_f step 1
-///     %extracted = extract input<h x w> from result<h x w x n x f>
-///     %ret = linalg.matmul AT, %extracted
-///     %ret = linalg.matmul %ret, A
-///     %inserted = insert %ret into ret<n x h x w x f>
+/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
+/// and tileW. After the transformation, we get
+///
+/// scf.for %h = 0 to tileH step 1
+///   scf.for %w = 0 to tileW step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %f = 0 to F step 1
+///         %extracted = extract %extracted<alphaH x alphaW> from
+///                              %input<alphaH x alphaW x tileH x tileW x N x F>
+///                              at [0, 0, %h, %w, %n, %f]
+///         %ret = linalg.matmul AT, %extracted
+///         %ret = linalg.matmul %ret, A
+///         %inserted = insert %ret<m x m> into
+///                            %output<N x H x W x F>
+///                            at [%n, (%h x m), (%w x m), %f]
 FailureOr<Operation *>
 decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                    linalg::WinogradOutputTransformOp op);

>From 73696eec36aa8cebc80cf093b995a7fb34064951 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 1 Aug 2024 08:44:10 +0100
Subject: [PATCH 5/7] address Max191's comments

---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  49 ++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 279 ++++++++++---
 .../Linalg/Transforms/WinogradConv2D.cpp      |   7 -
 .../transform-tile-and-winograd-rewrite.mlir  |  12 +-
 .../Linalg/transform-tile-winograd.mlir       | 392 ++++++++++++++++++
 5 files changed, 664 insertions(+), 75 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-winograd.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index bc2df7082ef8e3..aaf243e0157267 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -155,7 +155,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
 }
 
 def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
-    [AllElementTypesMatch<["filter", "output"]>]> {
+    [AllElementTypesMatch<["filter", "output"]>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd filter transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,6 +195,20 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
   }];
+  let extraClassDeclaration = [{
+    ShapedType getFilterOperandType() {
+      return cast<ShapedType>(getFilter().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getFilterOperandRank() {
+      return getFilterOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -234,6 +253,20 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
   }];
+  let extraClassDeclaration = [{
+    ShapedType getInputOperandType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getInputOperandRank() {
+      return getInputOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -278,6 +311,20 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
   }];
+  let extraClassDeclaration = [{
+    ShapedType getValueOperandType() {
+      return cast<ShapedType>(getValue().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getValueOperandRank() {
+      return getValueOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+  }];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index da55360c82eb0b..ee633a715cc9ec 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2881,6 +2881,105 @@ LogicalResult WinogradFilterTransformOp::verify() {
   return success();
 }
 
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  Value filter = getFilter();
+  SmallVector<Range> loopBounds;
+  Range fRange;
+  fRange.offset = zeroAttr;
+  // 0 is the index of F in the filter.
+  fRange.size = getDimValue(builder, loc, filter, 0);
+  fRange.stride = oneAttr;
+  loopBounds.emplace_back(fRange);
+  Range cRange;
+  cRange.offset = zeroAttr;
+  // 3 is the index of C in the filter.
+  cRange.size = getDimValue(builder, loc, filter, 3);
+  cRange.stride = oneAttr;
+  loopBounds.emplace_back(cRange);
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  int64_t filterRank = getFilterOperandRank();
+  SmallVector<utils::IteratorType> iteratorTypes(filterRank - 2,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  ShapedType filterType = getFilterOperandType();
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = filterH != 1 ? alpha : 1;
+  int64_t alphaW = filterW != 1 ? alpha : 1;
+  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  resultOffsets.append({zeroAttr, zeroAttr, offsets[1], offsets[0]});
+  resultSizes.append({alphaHAttr, alphaWAttr, sizes[1], sizes[0]});
+
+  return success();
+}
+
+/// Implement tiling for winograd_filter_transform.
+/// The input of winograd_filter_transform is (F, KH, KW, C).
+/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
+/// Users can specify the tile sizes of F and C.
+/// `offsets` are the values for the offsets of F and C for one tile.
+/// `sizes` are the values for the sizes of F and C for one tile.
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  ShapedType filterType = getFilterOperandType();
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
+  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  sliceOffsets.append({offsets[0], zeroAttr, zeroAttr, offsets[1]});
+  sliceSizes.append({sizes[0], filterHAttr, filterWAttr, sizes[1]});
+  int64_t filterRank = getFilterOperandRank();
+  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
+  Location loc = getLoc();
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
+
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+                                   resultSizes)))
+    return failure();
+
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+
+  SmallVector<Type> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // WinogradInputTransformOp
 //===----------------------------------------------------------------------===//
@@ -2925,22 +3024,23 @@ LogicalResult WinogradInputTransformOp::verify() {
 SmallVector<Range>
 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
-  IndexType indexType = builder.getIndexType();
-  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
-  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
   Value output = getOutput();
-  SmallVector<Range> loopBounds(6);
-  for (unsigned dim = 0; dim < 6; ++dim) {
-    loopBounds[dim].offset = zeroAttr;
-    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
-    loopBounds[dim].stride = oneAttr;
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<Range> loopBounds(outputRank - 2);
+  for (unsigned dim = 2; dim < outputRank; ++dim) {
+    loopBounds[dim - 2].offset = zeroAttr;
+    loopBounds[dim - 2].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim - 2].stride = oneAttr;
   }
   return loopBounds;
 }
 
 SmallVector<utils::IteratorType>
 WinogradInputTransformOp::getLoopIteratorTypes() {
-  SmallVector<utils::IteratorType> iteratorTypes(6,
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<utils::IteratorType> iteratorTypes(outputRank - 2,
                                                  utils::IteratorType::parallel);
   return iteratorTypes;
 }
@@ -2950,52 +3050,79 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  ShapedType inputType = getInputOperandType();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
   resultOffsets.append(
-      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+      {zeroAttr, zeroAttr, offsets[0], offsets[1], offsets[2], offsets[3]});
   resultSizes.append(
-      {sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
+      {alphaHAttr, alphaWAttr, sizes[0], sizes[1], sizes[2], sizes[3]});
 
   return success();
 }
 
+/// Implement tiling for winograd_input_transform.
+/// The input of winograd_input_transform is (N, H, W, C).
+/// The output of winograd_input_transform is
+/// (alphaH, alphaW, tileH, tileW, N, C).
+/// Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` and
+/// `sizes` are the per-tile offsets and sizes of tileH, tileW, N, and C.
 FailureOr<TilingResult>
 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  Value input = getInput();
-  auto inputType = cast<ShapedType>(input.getType());
+  ShapedType inputType = getInputOperandType();
   ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[1];
   int64_t inputW = inputShape[2];
   int64_t m = getM();
   int64_t r = getR();
-  int64_t alpha = m + r - 1;
-  int64_t alphaH = inputH != 1 ? alpha : 1;
-  int64_t alphaW = inputW != 1 ? alpha : 1;
-  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
-  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
   Location loc = getLoc();
-  SmallVector<Value> tiledOperands;
-  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
-
   MLIRContext *context = builder.getContext();
-  auto affineMap =
+  auto offsetAffineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
-  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+  Value mappedOffsetH = builder.create<affine::AffineApplyOp>(
+      loc, offsetAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[0]));
+  Value mappedOffsetW = builder.create<affine::AffineApplyOp>(
+      loc, offsetAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[1]));
+  auto sizeAffineMap = AffineMap::get(
+      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
+  Value mappedSizeH = builder.create<affine::AffineApplyOp>(
+      loc, sizeAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, sizes[0]));
+  Value mappedSizeW = builder.create<affine::AffineApplyOp>(
+      loc, sizeAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, sizes[1]));
+
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
-  sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
-  sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
-  SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+  OpFoldResult offsetH =
+      inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
+  OpFoldResult offsetW =
+      inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
+  sliceOffsets.append({offsets[2], offsetH, offsetW, offsets[3]});
+  OpFoldResult sizeH =
+      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+  sliceSizes.append({sizes[2], sizeH, sizeW, sizes[3]});
+  int64_t inputRank = getInputOperandRank();
+  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
 
@@ -3004,7 +3131,8 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                    resultSizes)))
     return failure();
 
-  SmallVector<OpFoldResult> outputStrides(6, oneAttr);
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getOutput(), resultOffsets, resultSizes, outputStrides));
 
@@ -3032,7 +3160,8 @@ LogicalResult WinogradOutputTransformOp::verify() {
   bool leftTransform = valueH != 1;
   bool rightTransform = valueW != 1;
 
-  SmallVector<int64_t> expectedOutputShape(4, valueH);
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
   if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
     expectedOutputShape[1] = ShapedType::kDynamic;
   } else {
@@ -3061,22 +3190,23 @@ LogicalResult WinogradOutputTransformOp::verify() {
 SmallVector<Range>
 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
-  IndexType indexType = builder.getIndexType();
-  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
-  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
   Value value = getValue();
-  SmallVector<Range> loopBounds(6);
-  for (unsigned dim = 0; dim < 6; ++dim) {
-    loopBounds[dim].offset = zeroAttr;
-    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
-    loopBounds[dim].stride = oneAttr;
+  int64_t valueRank = getValueOperandRank();
+  SmallVector<Range> loopBounds(valueRank - 2);
+  for (unsigned dim = 2; dim < valueRank; ++dim) {
+    loopBounds[dim - 2].offset = zeroAttr;
+    loopBounds[dim - 2].size = getDimValue(builder, loc, value, dim);
+    loopBounds[dim - 2].stride = oneAttr;
   }
   return loopBounds;
 }
 
 SmallVector<utils::IteratorType>
 WinogradOutputTransformOp::getLoopIteratorTypes() {
-  SmallVector<utils::IteratorType> iteratorTypes(6,
+  int64_t valueRank = getValueOperandRank();
+  SmallVector<utils::IteratorType> iteratorTypes(valueRank - 2,
                                                  utils::IteratorType::parallel);
   return iteratorTypes;
 }
@@ -3085,32 +3215,49 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  Value output = getOutput();
-  auto outputType = cast<ShapedType>(output.getType());
-  ArrayRef<int64_t> outputShape = outputType.getShape();
-  int64_t outputH = outputShape[1];
-  int64_t outputW = outputShape[2];
   int64_t m = getM();
-  IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
-  IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
 
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
   auto affineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+  Value mappedOffsetH = builder.create<affine::AffineApplyOp>(
       loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
-  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[0]));
+  Value mappedOffsetW = builder.create<affine::AffineApplyOp>(
       loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[1]));
+  Value mappedSizeH = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueOrCreateConstantIndexOp(builder, loc, sizes[0]));
+  Value mappedSizeW = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueOrCreateConstantIndexOp(builder, loc, sizes[1]));
 
-  resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
-  resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
+  ShapedType valueType = getValueOperandType();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  OpFoldResult offsetH =
+      valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
+  OpFoldResult offsetW =
+      valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
+  OpFoldResult sizeH =
+      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+
+  resultOffsets.append({offsets[2], offsetH, offsetW, offsets[3]});
+  resultSizes.append({sizes[2], sizeH, sizeW, sizes[3]});
   return success();
 }
 
+/// Implement tiling for winograd_output_transform.
+/// The input of winograd_output_transform is
+/// (alphaH, alphaW, tileH, tileW, N, F).
+/// The output of winograd_output_transform is (N, H, W, F).
+/// Users can specify the tile sizes of tileH, tileW, N, and F. `offsets` and
+/// `sizes` are the per-tile offsets and sizes of tileH, tileW, N, and F.
 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes) {
@@ -3120,10 +3267,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
+  ShapedType valueType = getValueOperandType();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t alphaH = valueShape[0];
+  int64_t alphaW = valueShape[1];
+  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
   sliceOffsets.append(
-      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
-  sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
-  SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
+      {zeroAttr, zeroAttr, offsets[0], offsets[1], offsets[2], offsets[3]});
+  sliceSizes.append(
+      {alphaHAttr, alphaWAttr, sizes[0], sizes[1], sizes[2], sizes[3]});
+  int64_t valueRank = getValueOperandRank();
+  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
 
@@ -3132,7 +3288,8 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
                                    resultSizes)))
     return failure();
 
-  SmallVector<OpFoldResult> strides(4, oneAttr);
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getOutput(), resultOffsets, resultSizes, strides));
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index c01a5f4b055f1a..b65b18699a15aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -490,8 +490,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
   Type elementType = inputType.getElementType();
   auto inputShape = inputType.getShape(); // N, H, W, C
   int64_t inputN = inputShape[0];
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
   int64_t inputC = inputShape[3];
   auto valueType = cast<ShapedType>(retValue.getType());
   auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
@@ -500,11 +498,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
   int64_t alphaH = leftTransform ? m + r - 1 : 1;
   int64_t alphaW = rightTransform ? m + r - 1 : 1;
 
-  if ((inputH != (tileH * m) + (r - 1)) && inputH != 1)
-    return Value();
-  if ((inputW != (tileW * m) + (r - 1)) && inputW != 1)
-    return Value();
-
   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                        ValueRange args) -> scf::ValueVector {
     Value tileHIter = ivs[0];
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 6bb3fb1423edc6..54eb755047f18b 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -18,9 +18,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
@@ -129,9 +129,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
@@ -235,9 +235,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
new file mode 100644
index 00000000000000..b1b2818db38fc7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -0,0 +1,392 @@
+// RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s
+
+func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:    %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK:    %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:    %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK:      %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
+// CHECK:      %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+
+// -----
+
+func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:     %[[C5:.*]] = arith.constant 5 : index
+// CHECK:     %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C2_1]]
+// CHECK:       %[[C5_2:.*]] = arith.constant 5 : index
+// CHECK:       %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
+// CHECK:       %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
+// CHECK:       %[[S4:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
+
+// -----
+
+func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  return %0 : tensor<6x1x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x1x5xf32>, %[[ARG1:.*]]: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:     %[[C5:.*]] = arith.constant 5 : index
+// CHECK:     %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK:       %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
+// CHECK:       %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+
+// -----
+
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop3:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK:   %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:   %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:   %[[C1_3:.*]] = arith.constant 1 : index
+// CHECK:   %[[S5:.*]] = affine.apply #[[$MAP1]](%[[C1_3]])
+// CHECK:   %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK:   %[[S6:.*]] = affine.apply #[[$MAP1]](%[[C1_4]])
+// CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
+// CHECK:   %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:   %[[S7:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+
+// -----
+
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK:     %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK:     %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK:     %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
+// CHECK:       %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK:       %[[C5:.*]] = arith.constant 5 : index
+// CHECK:       %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
+// CHECK:         %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:         %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:         %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK:         %[[S7:.*]] = affine.apply #[[$MAP1]](%[[C1_8]])
+// CHECK:         %[[C1_9:.*]] = arith.constant 1 : index
+// CHECK:         %[[S8:.*]] = affine.apply #[[$MAP1]](%[[C1_9]])
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
+// CHECK:         %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
+// CHECK:         %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+
+// -----
+
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
+// CHECK:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
+// CHECK:     %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK:     %[[C2_5:.*]] = arith.constant 2 : index
+// CHECK:     %[[C2_6:.*]] = arith.constant 2 : index
+// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C2_5]] step %[[C2_6]]
+// CHECK:       %[[C0_7:.*]] = arith.constant 0 : index
+// CHECK:       %[[C5:.*]] = arith.constant 5 : index
+// CHECK:       %[[C2_8:.*]] = arith.constant 2 : index
+// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_7]] to %[[C5]] step %[[C2_8]]
+// CHECK:         %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG8]])
+// CHECK:         %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG2]])
+// CHECK:         %[[S7:.*]] = affine.apply #[[$MAP1]](%[[ARG4]])
+// CHECK:         %[[C2_10:.*]] = arith.constant 2 : index
+// CHECK:         %[[S8:.*]] = affine.apply #[[$MAP2]](%[[C2_10]])
+// CHECK:         %[[C2_11:.*]] = arith.constant 2 : index
+// CHECK:         %[[S9:.*]] = affine.apply #[[$MAP2]](%[[C2_11]])
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
+// CHECK:         %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
+// CHECK:         %[[S10:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+
+// -----
+
+func.func @tile_winograd_input(%arg0: tensor<2x1x10x5xf32>, %arg1: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
+  return %0 : tensor<1x6x1x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x1x10x5xf32>, %[[ARG1:.*]]: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C1_0:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_0]]
+// CHECK:     %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]]
+// CHECK:       %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK:       %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK:       %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
+// CHECK:         %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK:         %[[C5:.*]] = arith.constant 5 : index
+// CHECK:         %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
+// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:           %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP1]](%[[C1_8]])
+// CHECK:           %[[C1_9:.*]] = arith.constant 1 : index
+// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP1]](%[[C1_9]])
+// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], 0, %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
+// CHECK:           %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
+// CHECK:           %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+
+// -----
+
+func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x2x2xf32>, %arg1: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @tile_winograd_output(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<6x6x2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:     %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK:       %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:       %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[C1_3:.*]] = arith.constant 1 : index
+// CHECK:       %[[S5:.*]] = affine.apply #[[$MAP0]](%[[C1_3]])
+// CHECK:       %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK:       %[[S6:.*]] = affine.apply #[[$MAP0]](%[[C1_4]])
+// CHECK:       %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
+
+// -----
+
+func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x3x5xf32>, %arg1: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
+  return %0 : tensor<3x8x8x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 3)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL:  func.func @tile_winograd_output(
+// CHECK-SAME:   %[[ARG0:.*]]: tensor<6x6x2x2x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
+// CHECK:    %[[S0:.*]] = tensor.empty() : tensor<3x8x8x5xf32>
+// CHECK:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK:    %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK:    %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
+// CHECK:      %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:      %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK:      %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK:      %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
+// CHECK:        %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK:        %[[C3:.*]] = arith.constant 3 : index
+// CHECK:        %[[C2_5:.*]] = arith.constant 2 : index
+// CHECK:        %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]]
+// CHECK:          %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK:          %[[C5:.*]] = arith.constant 5 : index
+// CHECK:          %[[C2_7:.*]] = arith.constant 2 : index
+// CHECK:          %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]]
+// CHECK:            %[[C3_8:.*]] = arith.constant 3 : index
+// CHECK:            %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]])
+// CHECK:            %[[C5_9:.*]] = arith.constant 5 : index
+// CHECK:            %[[S6:.*]] = affine.min #[[$MAP1]](%[[ARG8]])
+// CHECK:            %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, %[[S5]], %[[S6]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x3x5xf32> to tensor<6x6x2x2x?x?xf32>
+// CHECK:            %[[S7:.*]] = affine.apply #[[$MAP2]](%[[ARG2]])
+// CHECK:            %[[S8:.*]] = affine.apply #[[$MAP2]](%[[ARG4]])
+// CHECK:            %[[C2_10:.*]] = arith.constant 2 : index
+// CHECK:            %[[S9:.*]] = affine.apply #[[$MAP2]](%[[C2_10]])
+// CHECK:            %[[C2_11:.*]] = arith.constant 2 : index
+// CHECK:            %[[S10:.*]] = affine.apply #[[$MAP2]](%[[C2_11]])
+// CHECK:            %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
+
+// -----
+
+func.func @tile_winograd_output(%arg0 : tensor<6x1x2x1x3x5xf32>, %arg1: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
+  return %0 : tensor<3x8x1x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @tile_winograd_output(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<6x1x2x1x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:     %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]]
+// CHECK:       %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK:       %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]]
+// CHECK:         %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK:         %[[C5:.*]] = arith.constant 5 : index
+// CHECK:         %[[C1_6:.*]] = arith.constant 1 : index
+// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]]
+// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
+// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:           %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP0]](%[[C1_7]])
+// CHECK:           %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP0]](%[[C1_8]])
+// CHECK:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
+// CHECK:           %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>

>From 209b4e99388463245698ca0cc201a2a664d7719d Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Sun, 4 Aug 2024 02:36:04 +0100
Subject: [PATCH 6/7] Update test cases

---
 .../Linalg/transform-tile-winograd.mlir       | 189 +++++++++---------
 1 file changed, 94 insertions(+), 95 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
index b1b2818db38fc7..2cad4fa37530d6 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -15,13 +15,13 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: func.func @tile_winograd_filter(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
-// CHECK:  %[[C0:.*]] = arith.constant 0 : index
-// CHECK:  %[[C2:.*]] = arith.constant 2 : index
-// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:  %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:  %[[C1_1:.*]] = arith.constant 1 : index
 // CHECK:  %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:    %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK:    %[[C5:.*]] = arith.constant 5 : index
-// CHECK:    %[[C1_1:.*]] = arith.constant 1 : index
 // CHECK:    %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
 // CHECK:      %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
@@ -42,16 +42,16 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
 // CHECK-LABEL: func.func @tile_winograd_filter(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
-// CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2_1:.*]] = arith.constant 2 : index
 // CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK:     %[[C5:.*]] = arith.constant 5 : index
-// CHECK:     %[[C2_1:.*]] = arith.constant 2 : index
 // CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C2_1]]
 // CHECK:       %[[C5_2:.*]] = arith.constant 5 : index
 // CHECK:       %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
@@ -76,13 +76,13 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: func.func @tile_winograd_filter(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x1x5xf32>, %[[ARG1:.*]]: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
-// CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_1:.*]] = arith.constant 1 : index
 // CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK:     %[[C5:.*]] = arith.constant 5 : index
-// CHECK:     %[[C1_1:.*]] = arith.constant 1 : index
 // CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
 // CHECK:       %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
@@ -107,13 +107,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK: %[[C2_1:.*]] = arith.constant 2 : index
-// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
 // CHECK:   %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
 // CHECK:   %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
@@ -144,21 +144,21 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_7:.*]] = arith.constant 1 : index
 // CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:   %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK:   %[[C2_1:.*]] = arith.constant 2 : index
-// CHECK:   %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
-// CHECK:     %[[C0_3:.*]] = arith.constant 0 : index
-// CHECK:     %[[C2_4:.*]] = arith.constant 2 : index
-// CHECK:     %[[C1_5:.*]] = arith.constant 1 : index
 // CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
-// CHECK:       %[[C0_6:.*]] = arith.constant 0 : index
-// CHECK:       %[[C5:.*]] = arith.constant 5 : index
-// CHECK:       %[[C1_7:.*]] = arith.constant 1 : index
 // CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
 // CHECK:         %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
 // CHECK:         %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
@@ -185,26 +185,26 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_7:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_5:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_6:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_8:.*]] = arith.constant 2 : index
 // CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
-// CHECK:   %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK:   %[[C2_2:.*]] = arith.constant 2 : index
-// CHECK:   %[[C2_3:.*]] = arith.constant 2 : index
 // CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
-// CHECK:     %[[C0_4:.*]] = arith.constant 0 : index
-// CHECK:     %[[C2_5:.*]] = arith.constant 2 : index
-// CHECK:     %[[C2_6:.*]] = arith.constant 2 : index
 // CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C2_5]] step %[[C2_6]]
-// CHECK:       %[[C0_7:.*]] = arith.constant 0 : index
-// CHECK:       %[[C5:.*]] = arith.constant 5 : index
-// CHECK:       %[[C2_8:.*]] = arith.constant 2 : index
 // CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_7]] to %[[C5]] step %[[C2_8]]
 // CHECK:         %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG8]])
 // CHECK:         %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG2]])
@@ -236,21 +236,21 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x1x10x5xf32>, %[[ARG1:.*]]: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
-// CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK:   %[[C1_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_7:.*]] = arith.constant 1 : index
 // CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_0]]
-// CHECK:     %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK:     %[[C2:.*]] = arith.constant 2 : index
-// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]]
-// CHECK:       %[[C0_3:.*]] = arith.constant 0 : index
-// CHECK:       %[[C2_4:.*]] = arith.constant 2 : index
-// CHECK:       %[[C1_5:.*]] = arith.constant 1 : index
 // CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
-// CHECK:         %[[C0_6:.*]] = arith.constant 0 : index
-// CHECK:         %[[C5:.*]] = arith.constant 5 : index
-// CHECK:         %[[C1_7:.*]] = arith.constant 1 : index
 // CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
 // CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
 // CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
@@ -280,13 +280,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
 // CHECK-LABEL: func.func @tile_winograd_output(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<6x6x2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK:     %[[C2_1:.*]] = arith.constant 2 : index
-// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
 // CHECK:       %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
@@ -312,27 +312,26 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 3)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 3, 2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)>
 // CHECK-LABEL:  func.func @tile_winograd_output(
 // CHECK-SAME:   %[[ARG0:.*]]: tensor<6x6x2x2x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
-// CHECK:    %[[S0:.*]] = tensor.empty() : tensor<3x8x8x5xf32>
-// CHECK:    %[[C0:.*]] = arith.constant 0 : index
-// CHECK:    %[[C2:.*]] = arith.constant 2 : index
-// CHECK:    %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:    %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C2_5:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C2_7:.*]] = arith.constant 2 : index
 // CHECK:    %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
-// CHECK:      %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK:      %[[C2_2:.*]] = arith.constant 2 : index
-// CHECK:      %[[C2_3:.*]] = arith.constant 2 : index
 // CHECK:      %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
-// CHECK:        %[[C0_4:.*]] = arith.constant 0 : index
-// CHECK:        %[[C3:.*]] = arith.constant 3 : index
-// CHECK:        %[[C2_5:.*]] = arith.constant 2 : index
 // CHECK:        %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]]
-// CHECK:          %[[C0_6:.*]] = arith.constant 0 : index
-// CHECK:          %[[C5:.*]] = arith.constant 5 : index
-// CHECK:          %[[C2_7:.*]] = arith.constant 2 : index
 // CHECK:          %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]]
 // CHECK:            %[[C3_8:.*]] = arith.constant 3 : index
 // CHECK:            %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]])
@@ -365,21 +364,21 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
 // CHECK-LABEL: func.func @tile_winograd_output(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<6x1x2x1x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
-// CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_6:.*]] = arith.constant 1 : index
 // CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK:     %[[C1_1:.*]] = arith.constant 1 : index
-// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]]
-// CHECK:       %[[C0_3:.*]] = arith.constant 0 : index
-// CHECK:       %[[C3:.*]] = arith.constant 3 : index
-// CHECK:       %[[C1_4:.*]] = arith.constant 1 : index
 // CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]]
-// CHECK:         %[[C0_5:.*]] = arith.constant 0 : index
-// CHECK:         %[[C5:.*]] = arith.constant 5 : index
-// CHECK:         %[[C1_6:.*]] = arith.constant 1 : index
 // CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]]
 // CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
 // CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])

>From f7a02dbbf2967e5a97d42990277fbef0932ace53 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 14 Aug 2024 04:12:37 +0100
Subject: [PATCH 7/7] Address Max191's comments

---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  54 ++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 183 +++---
 .../transform-tile-and-winograd-rewrite.mlir  | 272 ++++----
 .../Linalg/transform-tile-winograd.mlir       | 595 ++++++++++++------
 4 files changed, 690 insertions(+), 414 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index aaf243e0157267..e82c654afe86a8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -208,6 +208,18 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
     int64_t getOutputOperandRank() {
       return getOutputOperandType().getRank();
     }
+    int64_t getFilterFDim() {
+      return 0;
+    }
+    int64_t getFilterHDim() {
+      return 1;
+    }
+    int64_t getFilterWDim() {
+      return 2;
+    }
+    int64_t getFilterCDim() {
+      return 3;
+    }
   }];
   let hasVerifier = 1;
 }
@@ -266,6 +278,30 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
     int64_t getOutputOperandRank() {
       return getOutputOperandType().getRank();
     }
+    int64_t getInputNDim() {
+      return 0;
+    }
+    int64_t getInputHDim() {
+      return 1;
+    }
+    int64_t getInputWDim() {
+      return 2;
+    }
+    int64_t getInputCDim() {
+      return 3;
+    }
+    int64_t getOutputTileHDim() {
+      return 2;
+    }
+    int64_t getOutputTileWDim() {
+      return 3;
+    }
+    int64_t getOutputNDim() {
+      return 4;
+    }
+    int64_t getOutputCDim() {
+      return 5;
+    }
   }];
   let hasVerifier = 1;
 }
@@ -324,6 +360,24 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
     int64_t getOutputOperandRank() {
       return getOutputOperandType().getRank();
     }
+    int64_t getValueAlphaHDim() {
+      return 0;
+    }
+    int64_t getValueAlphaWDim() {
+      return 1;
+    }
+    int64_t getValueTileHDim() {
+      return 2;
+    }
+    int64_t getValueTileWDim() {
+      return 3;
+    }
+    int64_t getValueNDim() {
+      return 4;
+    }
+    int64_t getValueFDim() {
+      return 5;
+    }
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ee633a715cc9ec..b30e7e8db393e2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2855,8 +2855,8 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
 LogicalResult WinogradFilterTransformOp::verify() {
   auto filterType = cast<ShapedType>(getFilter().getType());
   ArrayRef<int64_t> filterShape = filterType.getShape();
-  int64_t filterH = filterShape[1];
-  int64_t filterW = filterShape[2];
+  int64_t filterH = filterShape[getFilterHDim()];
+  int64_t filterW = filterShape[getFilterWDim()];
   int64_t r = getR();
   int64_t m = getM();
 
@@ -2870,8 +2870,8 @@ LogicalResult WinogradFilterTransformOp::verify() {
   SmallVector<int64_t> expectedOutputShape;
   expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
   expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
-  expectedOutputShape.push_back(filterShape[3]);
-  expectedOutputShape.push_back(filterShape[0]);
+  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
+  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
 
   auto outputType = cast<ShapedType>(getOutput().getType());
   ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -2887,26 +2887,20 @@ WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
   IntegerAttr zeroAttr = builder.getIndexAttr(0);
   IntegerAttr oneAttr = builder.getIndexAttr(1);
   Value filter = getFilter();
-  SmallVector<Range> loopBounds;
-  Range fRange;
-  fRange.offset = zeroAttr;
-  // 0 is the index of F in the filter.
-  fRange.size = getDimValue(builder, loc, filter, 0);
-  fRange.stride = oneAttr;
-  loopBounds.emplace_back(fRange);
-  Range cRange;
-  cRange.offset = zeroAttr;
-  // 3 is the index of C in the filter.
-  cRange.size = getDimValue(builder, loc, filter, 3);
-  cRange.stride = oneAttr;
-  loopBounds.emplace_back(cRange);
+  int64_t filterRank = getFilterOperandRank();
+  SmallVector<Range> loopBounds(filterRank);
+  for (unsigned dim = 0; dim < filterRank; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
+    loopBounds[dim].stride = oneAttr;
+  }
   return loopBounds;
 }
 
 SmallVector<utils::IteratorType>
 WinogradFilterTransformOp::getLoopIteratorTypes() {
   int64_t filterRank = getFilterOperandRank();
-  SmallVector<utils::IteratorType> iteratorTypes(filterRank - 2,
+  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
                                                  utils::IteratorType::parallel);
   return iteratorTypes;
 }
@@ -2918,8 +2912,8 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   ShapedType filterType = getFilterOperandType();
   ArrayRef<int64_t> filterShape = filterType.getShape();
-  int64_t filterH = filterShape[1];
-  int64_t filterW = filterShape[2];
+  int64_t filterH = filterShape[getFilterHDim()];
+  int64_t filterW = filterShape[getFilterWDim()];
   int64_t m = getM();
   int64_t r = getR();
   int64_t alpha = m + r - 1;
@@ -2928,8 +2922,10 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
-  resultOffsets.append({zeroAttr, zeroAttr, offsets[1], offsets[0]});
-  resultSizes.append({alphaHAttr, alphaWAttr, sizes[1], sizes[0]});
+  resultOffsets.append(
+      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
+  resultSizes.append(
+      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
 
   return success();
 }
@@ -2938,8 +2934,8 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
 /// The input of winograd_filter_transform is (F, KH, KW, C).
 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
 /// Users can specify the tile sizes of F and C.
-/// `offsets` are the values for the offsets of F, C for one tile.
-/// `sizes` are the values for the sizes of F, C for one tile.
+/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
+/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes) {
@@ -2947,15 +2943,17 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   ShapedType filterType = getFilterOperandType();
   ArrayRef<int64_t> filterShape = filterType.getShape();
-  int64_t filterH = filterShape[1];
-  int64_t filterW = filterShape[2];
+  int64_t filterH = filterShape[getFilterHDim()];
+  int64_t filterW = filterShape[getFilterWDim()];
   IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
   IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
-  sliceOffsets.append({offsets[0], zeroAttr, zeroAttr, offsets[1]});
-  sliceSizes.append({sizes[0], filterHAttr, filterWAttr, sizes[1]});
+  sliceOffsets.append(
+      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
+  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
+                     sizes[getFilterCDim()]});
   int64_t filterRank = getFilterOperandRank();
   SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
   Location loc = getLoc();
@@ -2987,8 +2985,8 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
 LogicalResult WinogradInputTransformOp::verify() {
   auto inputType = cast<ShapedType>(getInput().getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
+  int64_t inputH = inputShape[getInputHDim()];
+  int64_t inputW = inputShape[getInputWDim()];
   int m = getM();
   int r = getR();
   int64_t tileSize = m + r - 1;
@@ -3010,8 +3008,8 @@ LogicalResult WinogradInputTransformOp::verify() {
     expectedOutputShape[1] = rightTransform ? tileSize : 1;
     expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
   }
-  expectedOutputShape[4] = inputShape[0];
-  expectedOutputShape[5] = inputShape[3];
+  expectedOutputShape[4] = inputShape[getInputNDim()];
+  expectedOutputShape[5] = inputShape[getInputCDim()];
 
   auto outputType = cast<ShapedType>(getOutput().getType());
   ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3028,11 +3026,12 @@ WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
   IntegerAttr oneAttr = builder.getIndexAttr(1);
   Value output = getOutput();
   int64_t outputRank = getOutputOperandRank();
-  SmallVector<Range> loopBounds(outputRank - 2);
-  for (unsigned dim = 2; dim < outputRank; ++dim) {
-    loopBounds[dim - 2].offset = zeroAttr;
-    loopBounds[dim - 2].size = getDimValue(builder, loc, output, dim);
-    loopBounds[dim - 2].stride = oneAttr;
+  SmallVector<Range> loopBounds(outputRank);
+  for (unsigned dim = 0; dim < outputRank; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    // alphaH, alphaW, tileH, tileW, N, C
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = oneAttr;
   }
   return loopBounds;
 }
@@ -3040,7 +3039,7 @@ WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
 SmallVector<utils::IteratorType>
 WinogradInputTransformOp::getLoopIteratorTypes() {
   int64_t outputRank = getOutputOperandRank();
-  SmallVector<utils::IteratorType> iteratorTypes(outputRank - 2,
+  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                  utils::IteratorType::parallel);
   return iteratorTypes;
 }
@@ -3052,8 +3051,8 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   ShapedType inputType = getInputOperandType();
   ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
+  int64_t inputH = inputShape[getInputHDim()];
+  int64_t inputW = inputShape[getInputWDim()];
   int64_t m = getM();
   int64_t r = getR();
   int64_t alpha = m + r - 1;
@@ -3062,10 +3061,12 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
-  resultOffsets.append(
-      {zeroAttr, zeroAttr, offsets[0], offsets[1], offsets[2], offsets[3]});
-  resultSizes.append(
-      {alphaHAttr, alphaWAttr, sizes[0], sizes[1], sizes[2], sizes[3]});
+  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
+                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
+                        offsets[getOutputCDim()]});
+  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
+                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
+                      sizes[getOutputCDim()]});
 
   return success();
 }
@@ -3084,8 +3085,8 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   ShapedType inputType = getInputOperandType();
   ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
+  int64_t inputH = inputShape[getInputHDim()];
+  int64_t inputW = inputShape[getInputWDim()];
   int64_t m = getM();
   int64_t r = getR();
 
@@ -3093,20 +3094,16 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   MLIRContext *context = builder.getContext();
   auto offsetAffineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-  Value mappedOffsetH = builder.create<affine::AffineApplyOp>(
-      loc, offsetAffineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[0]));
-  Value mappedOffsetW = builder.create<affine::AffineApplyOp>(
-      loc, offsetAffineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[1]));
+  Value mappedOffsetH = affine::makeComposedAffineApply(
+      builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
+  Value mappedOffsetW = affine::makeComposedAffineApply(
+      builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
   auto sizeAffineMap = AffineMap::get(
       1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
-  Value mappedSizeH = builder.create<affine::AffineApplyOp>(
-      loc, sizeAffineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, sizes[0]));
-  Value mappedSizeW = builder.create<affine::AffineApplyOp>(
-      loc, sizeAffineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, sizes[1]));
+  Value mappedSizeH = affine::makeComposedAffineApply(
+      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
+  Value mappedSizeW = affine::makeComposedAffineApply(
+      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
 
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
@@ -3115,12 +3112,14 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
       inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
   OpFoldResult offsetW =
       inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
-  sliceOffsets.append({offsets[2], offsetH, offsetW, offsets[3]});
+  sliceOffsets.append(
+      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
   OpFoldResult sizeH =
       inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
   OpFoldResult sizeW =
       inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
-  sliceSizes.append({sizes[2], sizeH, sizeW, sizes[3]});
+  sliceSizes.append(
+      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
   int64_t inputRank = getInputOperandRank();
   SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
@@ -3151,10 +3150,10 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
 LogicalResult WinogradOutputTransformOp::verify() {
   auto valueType = cast<ShapedType>(getValue().getType());
   ArrayRef<int64_t> valueShape = valueType.getShape();
-  int64_t valueH = valueShape[0];
-  int64_t valueW = valueShape[1];
-  int64_t valueTileH = valueShape[2];
-  int64_t valueTileW = valueShape[3];
+  int64_t valueH = valueShape[getValueAlphaHDim()];
+  int64_t valueW = valueShape[getValueAlphaWDim()];
+  int64_t valueTileH = valueShape[getValueTileHDim()];
+  int64_t valueTileW = valueShape[getValueTileWDim()];
   int m = getM();
   int r = getR();
   bool leftTransform = valueH != 1;
@@ -3176,8 +3175,8 @@ LogicalResult WinogradOutputTransformOp::verify() {
       return emitOpError("expect input width equals to input tile size");
     expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
   }
-  expectedOutputShape[0] = valueShape[4];
-  expectedOutputShape[3] = valueShape[5];
+  expectedOutputShape[0] = valueShape[getValueNDim()];
+  expectedOutputShape[3] = valueShape[getValueFDim()];
 
   auto outputType = cast<ShapedType>(getOutput().getType());
   ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3194,11 +3193,12 @@ WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
   IntegerAttr oneAttr = builder.getIndexAttr(1);
   Value value = getValue();
   int64_t valueRank = getValueOperandRank();
-  SmallVector<Range> loopBounds(valueRank - 2);
-  for (unsigned dim = 2; dim < valueRank; ++dim) {
-    loopBounds[dim - 2].offset = zeroAttr;
-    loopBounds[dim - 2].size = getDimValue(builder, loc, value, dim);
-    loopBounds[dim - 2].stride = oneAttr;
+  SmallVector<Range> loopBounds(valueRank);
+  for (unsigned dim = 0; dim < valueRank; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    // alphaH, alphaW, tileH, tileW, N, F
+    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
+    loopBounds[dim].stride = oneAttr;
   }
   return loopBounds;
 }
@@ -3206,7 +3206,7 @@ WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
 SmallVector<utils::IteratorType>
 WinogradOutputTransformOp::getLoopIteratorTypes() {
   int64_t valueRank = getValueOperandRank();
-  SmallVector<utils::IteratorType> iteratorTypes(valueRank - 2,
+  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                  utils::IteratorType::parallel);
   return iteratorTypes;
 }
@@ -3221,16 +3221,15 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
   MLIRContext *context = builder.getContext();
   auto affineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-  Value mappedOffsetH = builder.create<affine::AffineApplyOp>(
-      loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[0]));
-  Value mappedOffsetW = builder.create<affine::AffineApplyOp>(
-      loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[1]));
-  Value mappedSizeH = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueOrCreateConstantIndexOp(builder, loc, sizes[0]));
-  Value mappedSizeW = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueOrCreateConstantIndexOp(builder, loc, sizes[1]));
+
+  Value mappedOffsetH = affine::makeComposedAffineApply(
+      builder, loc, affineMap, offsets[getValueTileHDim()]);
+  Value mappedOffsetW = affine::makeComposedAffineApply(
+      builder, loc, affineMap, offsets[getValueTileWDim()]);
+  Value mappedSizeH = affine::makeComposedAffineApply(
+      builder, loc, affineMap, sizes[getValueTileHDim()]);
+  Value mappedSizeW = affine::makeComposedAffineApply(
+      builder, loc, affineMap, sizes[getValueTileWDim()]);
 
   ShapedType valueType = getValueOperandType();
   ArrayRef<int64_t> valueShape = valueType.getShape();
@@ -3247,8 +3246,10 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
   OpFoldResult sizeW =
       valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
 
-  resultOffsets.append({offsets[2], offsetH, offsetW, offsets[3]});
-  resultSizes.append({sizes[2], sizeH, sizeW, sizes[3]});
+  resultOffsets.append(
+      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
+  resultSizes.append(
+      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
   return success();
 }
 
@@ -3269,15 +3270,17 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
 
   ShapedType valueType = getValueOperandType();
   ArrayRef<int64_t> valueShape = valueType.getShape();
-  int64_t alphaH = valueShape[0];
-  int64_t alphaW = valueShape[1];
+  int64_t alphaH = valueShape[getValueAlphaHDim()];
+  int64_t alphaW = valueShape[getValueAlphaWDim()];
   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
-  sliceOffsets.append(
-      {zeroAttr, zeroAttr, offsets[0], offsets[1], offsets[2], offsets[3]});
-  sliceSizes.append(
-      {alphaHAttr, alphaWAttr, sizes[0], sizes[1], sizes[2], sizes[3]});
+  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
+                       offsets[getValueTileWDim()], offsets[getValueNDim()],
+                       offsets[getValueFDim()]});
+  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
+                     sizes[getValueTileWDim()], sizes[getValueNDim()],
+                     sizes[getValueFDim()]});
   int64_t valueRank = getValueOperandRank();
   SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 54eb755047f18b..c86b33ba88ee36 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -18,9 +18,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:4 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:4 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
@@ -36,6 +36,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func.func @conv2d
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
 // CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK:  %[[C6:.*]] = arith.constant 6 : index
 // CHECK:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK:  %[[C5:.*]] = arith.constant 5 : index
 // CHECK:  %[[C2:.*]] = arith.constant 2 : index
@@ -51,53 +52,49 @@ module attributes {transform.with_named_sequence} {
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
 // CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
-// CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK:          %[[S15:.*]] = linalg.matmul
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
-// CHECK:        scf.yield %[[S13]]
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      scf.yield %[[INSERTED_SLICE]]
-// CHECK:    scf.yield %[[S9]]
-// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK:      %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK:        %[[S11:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK:          %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG7]])
+// CHECK:          %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S12]], %[[S13]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK:          %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG7]], %[[ARG9]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:          %[[S14:.*]] = scf.for %[[ARG11:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG12:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:            %[[S15:.*]] = scf.for %[[ARG13:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG14:.*]] = %[[ARG12]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:              %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG11]], 0, 0, %[[ARG13]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<6x6xf32>
+// CHECK:              %[[S17:.*]] = linalg.matmul
+// CHECK:              %[[S19:.*]] = linalg.matmul
+// CHECK:              %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S19]] into %[[ARG14]][0, 0, 0, 0, %[[ARG11]], %[[ARG13]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x2x5xf32>
+// CHECK:          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG10]][0, 0, %[[ARG7]], %[[ARG9]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
+// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
 // CHECK:  %[[S6:.*]] = linalg.batch_matmul
-// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
 // CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
-// CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[S19:.*]] = linalg.matmul
-// CHECK:          %[[S20:.*]] = tensor.empty()
-// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
-// CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK:            linalg.yield %[[IN]] : f32
-// CHECK:          } -> tensor<4x4xf32>
-// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
-// CHECK:        scf.yield %[[S15]]
-// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      scf.yield %[[INSERTED_SLICE]]
-// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK:      %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK:        %[[S11:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK:          %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG7]], %[[ARG9]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK:          %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG7]])
+// CHECK:          %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S12]], %[[S13]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK:          %[[S14:.*]] = scf.for %[[ARG11:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG12:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:            %[[S17:.*]] = scf.for %[[ARG13:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG14:.*]] = %[[ARG12]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:              %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG11]], %[[ARG13]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK:              %[[S19:.*]] = linalg.matmul
+// CHECK:              %[[S21:.*]] = linalg.matmul
+// CHECK:              %[[S22:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:              %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
+// CHECK:              ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:                linalg.yield %[[IN]] : f32
+// CHECK:              } -> tensor<4x4xf32>
+// CHECK:              %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:              %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S24]] into %[[ARG14]][%[[ARG11]], 0, 0, %[[ARG13]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK:          %[[S15:.*]] = affine.apply #[[$MAP0]](%[[ARG7]])
+// CHECK:          %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG9]])
+// CHECK:          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG10]][0, %[[S15]], %[[S16]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
 
 // -----
 
@@ -129,9 +126,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:4 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:4 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
@@ -148,10 +145,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
 // CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
 // CHECK:  %[[C3:.*]] = arith.constant 3 : index
+// CHECK:  %[[C6:.*]] = arith.constant 6 : index
 // CHECK:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK:  %[[C5:.*]] = arith.constant 5 : index
 // CHECK:  %[[C2:.*]] = arith.constant 2 : index
 // CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:  %[[S0:.*]] = tensor.empty()
 // CHECK:  %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
@@ -164,56 +163,55 @@ module attributes {transform.with_named_sequence} {
 // CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
 // CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
-// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK:          %[[S15:.*]] = linalg.matmul
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:        scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      scf.yield %[[INSERTED_SLICE]]
-// CHECK:    scf.yield %[[S9]]
-// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK:      %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK:        %[[S11:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK:          %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG7]])
+// CHECK:          %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S12]], %[[S13]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
+// CHECK:          %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG7]], %[[ARG9]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:          %[[S14:.*]] = scf.for %[[ARG11:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG12:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:            %[[S15:.*]] = scf.for %[[ARG13:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG14:.*]] = %[[ARG12]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:              %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG11]], 0, 0, %[[ARG13]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<6x6xf32>
+// CHECK:              %[[S17:.*]] = linalg.matmul
+// CHECK:              %[[S19:.*]] = linalg.matmul
+// CHECK:              %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S19]] into %[[ARG14]][0, 0, 0, 0, %[[ARG11]], %[[ARG13]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x2x5xf32>
+// CHECK:          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG10]][0, 0, %[[ARG7]], %[[ARG9]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
 // CHECK:  %[[S6:.*]] = linalg.batch_matmul
-// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
-// CHECK:  %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK:  %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK:  ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
+// CHECK:    tensor.yield %[[CST_6]] : f32
+// CHECK:  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
 // CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK:  %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK:        %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
-// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[S19:.*]] = linalg.matmul
-// CHECK:          %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
-// CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK:            linalg.yield %[[IN]] : f32
-// CHECK:          } -> tensor<4x4xf32>
-// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_12]]
-// CHECK:        scf.yield %[[S15]] : tensor<2x4x4x2xf32>
-// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      scf.yield %[[INSERTED_SLICE]]
-// CHECK:    scf.yield %[[S9]]
-// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
-// CHECK:  return %[[EXTRACTED_SLICE]]
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK:      %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK:        %[[S11:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK:          %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG7]], %[[ARG9]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK:          %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG7]])
+// CHECK:          %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S12]], %[[S13]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
+// CHECK:          %[[S14:.*]] = scf.for %[[ARG11:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG12:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:            %[[S17:.*]] = scf.for %[[ARG13:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG14:.*]] = %[[ARG12]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:              %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG11]], %[[ARG13]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK:              %[[S19:.*]] = linalg.matmul
+// CHECK:              %[[S21:.*]] = linalg.matmul
+// CHECK:              %[[S22:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:              %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
+// CHECK:              ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:                linalg.yield %[[IN]] : f32
+// CHECK:              } -> tensor<4x4xf32>
+// CHECK:              %[[S24:.*]] = linalg.mul
+// CHECK:              %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S24]] into %[[ARG14]][%[[ARG11]], 0, 0, %[[ARG13]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK:          %[[S15:.*]] = affine.apply #[[$MAP0]](%[[ARG7]])
+// CHECK:          %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG9]])
+// CHECK:          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG10]][0, %[[S15]], %[[S16]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
+// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 
 // -----
 
@@ -235,9 +233,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:4 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:4 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
@@ -251,42 +249,40 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d_mx1_rx1
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK:   %[[C5:.*]] = arith.constant 5 : index
-// CHECK:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
-// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
-// CHECK:       %[[S9:.*]] = linalg.matmul
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
-// CHECK:       scf.yield %[[INSERTED_SLICE]]
-// CHECK:     scf.yield %[[S7]]
-// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
-// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
-// CHECK:       %[[S9:.*]] = linalg.matmul
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:       scf.yield %[[INSERTED_SLICE]]
-// CHECK:     scf.yield %[[S7]]
-// CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK:   %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK:   %[[S5:.*]] = linalg.batch_matmul
-// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
-// CHECK:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
-// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:       %[[S9:.*]] = linalg.matmul
-// CHECK:       %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
-// CHECK:       %[[S11:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) {
-// CHECK:       ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK:         linalg.yield %[[IN]] : f32
-// CHECK:       } -> tensor<4x1xf32>
-// CHECK:       %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
-// CHECK:       scf.yield %[[INSERTED_SLICE]]
-// CHECK:     scf.yield %[[S7]]
-// CHECK:   return %[[S6]]
+// CHECK:  %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK:  %[[C6:.*]] = arith.constant 6 : index
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[C5:.*]] = arith.constant 5 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<3x1xf32>
+// CHECK:      %[[S11:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x5x2xf32>
+// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK:      %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK:        %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<6x1xf32>
+// CHECK:        %[[S12:.*]] = linalg.matmul
+// CHECK:        %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG8]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x2x5xf32>
+// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK:  %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK:  %[[S6:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG2]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK:      %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK:        %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1xf32>
+// CHECK:        %[[S12:.*]] = linalg.matmul
+// CHECK:        %[[S13:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK:        %[[S14:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S13]] : tensor<4x1xf32>) {
+// CHECK:        ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:          linalg.yield %[[IN]] : f32
+// CHECK:        } -> tensor<4x1xf32>
+// CHECK:        %[[S15:.*]] = linalg.mul ins(%[[S14]], %[[S12]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S13]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK:        %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG8]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<2x4x1x2xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
index 2cad4fa37530d6..ea5d4e99c87311 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -8,24 +8,43 @@ func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK-LABEL: func.func @tile_winograd_filter(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
-// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:  %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:  %[[C1_1:.*]] = arith.constant 1 : index
-// CHECK:  %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:    %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
-// CHECK:      %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
-// CHECK:      %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C3_3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_6:.*]] = arith.constant 1 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C3]] step %[[C1_4]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C3_3]] step %[[C1_5]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C5]] step %[[C1_6]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG8]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG8]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[S5:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_7]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S5]] into %[[ARG9]][0, 0, %[[ARG8]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<6x6x5x2xf32>
 
 // -----
 
@@ -37,27 +56,47 @@ func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
+
 // CHECK-LABEL: func.func @tile_winograd_filter(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C3_3:.*]] = arith.constant 3 : index
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2_1:.*]] = arith.constant 2 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C2_1]]
-// CHECK:       %[[C5_2:.*]] = arith.constant 5 : index
-// CHECK:       %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
-// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
-// CHECK:       %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
-// CHECK:       %[[S4:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
+// CHECK-DAG:   %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2_6:.*]] = arith.constant 2 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C3]] step %[[C1_4]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C3_3]] step %[[C1_5]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C5]] step %[[C2_6]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:           %[[C5_7:.*]] = arith.constant 5 : index
+// CHECK-NEXT:           %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG8]])
+// CHECK-NEXT:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG8]]] [1, 3, 3, %[[S5]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG8]], %[[ARG2]]] [6, 6, %[[S5]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
+// CHECK-NEXT:           %[[S6:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_8]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S6]] into %[[ARG9]][0, 0, %[[ARG8]], %[[ARG2]]] [6, 6, %[[S5]], 1] [1, 1, 1, 1] : tensor<6x6x?x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<6x6x5x2xf32>
 
 // -----
 
@@ -69,24 +108,43 @@ func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK-LABEL: func.func @tile_winograd_filter(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x1x5xf32>, %[[ARG1:.*]]: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C1_1:.*]] = arith.constant 1 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
-// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
-// CHECK:       %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
-// CHECK:       %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1_3:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_6:.*]] = arith.constant 1 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1_3]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C3]] step %[[C1_4]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C1]] step %[[C1_5]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C5]] step %[[C1_6]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG8]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG8]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
+// CHECK-NEXT:           %[[S5:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_7]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S5]] into %[[ARG9]][0, 0, %[[ARG8]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<6x1x5x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<6x1x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<6x1x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<6x1x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<6x1x5x2xf32>
 
 // -----
 
@@ -98,32 +156,50 @@ func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop3:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)>
+
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
-// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
-// CHECK:   %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
-// CHECK:   %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:   %[[C1_3:.*]] = arith.constant 1 : index
-// CHECK:   %[[S5:.*]] = affine.apply #[[$MAP1]](%[[C1_3]])
-// CHECK:   %[[C1_4:.*]] = arith.constant 1 : index
-// CHECK:   %[[S6:.*]] = affine.apply #[[$MAP1]](%[[C1_4]])
-// CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
-// CHECK:   %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK:   %[[S7:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C6_3:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_6:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C6_3]] step %[[C1_5]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_6]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C2_4]] step %[[C1_7]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK-NEXT:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG8]])
+// CHECK-NEXT:           %[[S7:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:           %[[S8:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S5]], %[[S6]], 0] [2, %[[S7]], %[[S8]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG6]], %[[ARG8]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:           %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_8]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG6]], %[[ARG8]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<6x6x2x2x2x5xf32>
 
 // -----
 
@@ -135,40 +211,62 @@ func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop3:6 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)>
+
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C0_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2_4:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C1_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C1_7:.*]] = arith.constant 1 : index
-// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
-// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
-// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
-// CHECK:         %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
-// CHECK:         %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:         %[[C1_8:.*]] = arith.constant 1 : index
-// CHECK:         %[[S7:.*]] = affine.apply #[[$MAP1]](%[[C1_8]])
-// CHECK:         %[[C1_9:.*]] = arith.constant 1 : index
-// CHECK:         %[[S8:.*]] = affine.apply #[[$MAP1]](%[[C1_9]])
-// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
-// CHECK:         %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
-// CHECK:         %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C6_5:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_6:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_7:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_9:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_10:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_11:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_12:.*]] = arith.constant 1 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C6_5]] step %[[C1_8]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_9]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C2_6]] step %[[C1_10]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:           %[[S5:.*]] = scf.for %[[ARG10:.*]] = %[[C0_3]] to %[[C2_7]] step %[[C1_11]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:             %[[S6:.*]] = scf.for %[[ARG12:.*]] = %[[C0_4]] to %[[C5]] step %[[C1_12]] iter_args(%[[ARG13:.*]] = %[[ARG11]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:               %[[S7:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK-NEXT:               %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG8]])
+// CHECK-NEXT:               %[[S9:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:               %[[S10:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:               %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG10]], %[[S7]], %[[S8]], %[[ARG12]]] [1, %[[S9]], %[[S10]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
+// CHECK-NEXT:               %[[EXTRACTED_SLICE_13:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG6]], %[[ARG8]], %[[ARG10]], %[[ARG12]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
+// CHECK-NEXT:               %[[S11:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_13]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK-NEXT:               %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG13]][0, 0, %[[ARG6]], %[[ARG8]], %[[ARG10]], %[[ARG12]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:               scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[S6]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           scf.yield %[[S5]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<6x6x2x2x2x5xf32>
 
 // -----
 
@@ -180,42 +278,65 @@ func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop3:6 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK: #[[$MAP2:.+]] = affine_map<() -> (10)>
+
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C0_7:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2_5:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2_6:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2_8:.*]] = arith.constant 2 : index
-// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
-// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
-// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C2_5]] step %[[C2_6]]
-// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_7]] to %[[C5]] step %[[C2_8]]
-// CHECK:         %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG8]])
-// CHECK:         %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG2]])
-// CHECK:         %[[S7:.*]] = affine.apply #[[$MAP1]](%[[ARG4]])
-// CHECK:         %[[C2_10:.*]] = arith.constant 2 : index
-// CHECK:         %[[S8:.*]] = affine.apply #[[$MAP2]](%[[C2_10]])
-// CHECK:         %[[C2_11:.*]] = arith.constant 2 : index
-// CHECK:         %[[S9:.*]] = affine.apply #[[$MAP2]](%[[C2_11]])
-// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
-// CHECK:         %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
-// CHECK:         %[[S10:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C6_5:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_6:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_7:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2_9:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_10:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_11:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_12:.*]] = arith.constant 2 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C6_5]] step %[[C1_8]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C2]] step %[[C2_9]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C2_6]] step %[[C2_10]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:           %[[S5:.*]] = scf.for %[[ARG10:.*]] = %[[C0_3]] to %[[C2_7]] step %[[C2_11]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:             %[[S6:.*]] = scf.for %[[ARG12:.*]] = %[[C0_4]] to %[[C5]] step %[[C2_12]] iter_args(%[[ARG13:.*]] = %[[ARG11]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:               %[[C5_13:.*]] = arith.constant 5 : index
+// CHECK-NEXT:               %[[S7:.*]] = affine.min #[[$MAP0]](%[[ARG12]])
+// CHECK-NEXT:               %[[S8:.*]] = affine.apply #[[$MAP1]](%[[ARG6]])
+// CHECK-NEXT:               %[[S9:.*]] = affine.apply #[[$MAP1]](%[[ARG8]])
+// CHECK-NEXT:               %[[S10:.*]] = affine.apply #[[$MAP2]]()
+// CHECK-NEXT:               %[[S11:.*]] = affine.apply #[[$MAP2]]()
+// CHECK-NEXT:               %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG10]], %[[S8]], %[[S9]], %[[ARG12]]] [2, %[[S10]], %[[S11]], %[[S7]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
+// CHECK-NEXT:               %[[EXTRACTED_SLICE_14:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG6]], %[[ARG8]], %[[ARG10]], %[[ARG12]]] [6, 6, 2, 2, 2, %[[S7]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
+// CHECK-NEXT:               %[[S12:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_14]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+// CHECK-NEXT:               %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG13]][0, 0, %[[ARG6]], %[[ARG8]], %[[ARG10]], %[[ARG12]]] [6, 6, 2, 2, 2, %[[S7]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x?xf32> into tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:               scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[S6]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           scf.yield %[[S5]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<6x6x2x2x2x5xf32>
 
 // -----
 
@@ -227,40 +348,62 @@ func.func @tile_winograd_input(%arg0: tensor<2x1x10x5xf32>, %arg1: tensor<1x6x1x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop3:6 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)>
+
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x1x10x5xf32>, %[[ARG1:.*]]: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<1x6x1x2x2x5xf32>
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C0_3:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_6:.*]] = arith.constant 2 : index
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:   %[[C1_0:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C1_7:.*]] = arith.constant 1 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_0]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]]
-// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
-// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
-// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
-// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:           %[[C1_8:.*]] = arith.constant 1 : index
-// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP1]](%[[C1_8]])
-// CHECK:           %[[C1_9:.*]] = arith.constant 1 : index
-// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP1]](%[[C1_9]])
-// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], 0, %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
-// CHECK:           %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
-// CHECK:           %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+// CHECK-DAG:   %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_9:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_10:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_11:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_12:.*]] = arith.constant 1 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_7]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<1x6x1x2x2x5xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C6]] step %[[C1_8]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<1x6x1x2x2x5xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C1_5]] step %[[C1_9]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<1x6x1x2x2x5xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C2]] step %[[C1_10]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<1x6x1x2x2x5xf32>) {
+// CHECK-NEXT:           %[[S5:.*]] = scf.for %[[ARG10:.*]] = %[[C0_3]] to %[[C2_6]] step %[[C1_11]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<1x6x1x2x2x5xf32>) {
+// CHECK-NEXT:             %[[S6:.*]] = scf.for %[[ARG12:.*]] = %[[C0_4]] to %[[C5]] step %[[C1_12]] iter_args(%[[ARG13:.*]] = %[[ARG11]]) -> (tensor<1x6x1x2x2x5xf32>) {
+// CHECK-NEXT:               %[[S7:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK-NEXT:               %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG8]])
+// CHECK-NEXT:               %[[S9:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:               %[[S10:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:               %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG10]], 0, %[[S8]], %[[ARG12]]] [1, 1, %[[S10]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
+// CHECK-NEXT:               %[[EXTRACTED_SLICE_13:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG6]], %[[ARG8]], %[[ARG10]], %[[ARG12]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
+// CHECK-NEXT:               %[[S11:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_13]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+// CHECK-NEXT:               %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG13]][0, 0, %[[ARG6]], %[[ARG8]], %[[ARG10]], %[[ARG12]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x1x1x1xf32> into tensor<1x6x1x2x2x5xf32>
+// CHECK-NEXT:               scf.yield %[[INSERTED_SLICE]] : tensor<1x6x1x2x2x5xf32>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[S6]] : tensor<1x6x1x2x2x5xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           scf.yield %[[S5]] : tensor<1x6x1x2x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<1x6x1x2x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<1x6x1x2x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<1x6x1x2x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<1x6x1x2x2x5xf32>
 
 // -----
 
@@ -272,30 +415,54 @@ func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x2x2xf32>, %arg1: tensor<2
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (4)>
+
 // CHECK-LABEL: func.func @tile_winograd_output(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<6x6x2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C6_3:.*]] = arith.constant 6 : index
 // CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_4:.*]] = arith.constant 2 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
-// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK:       %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
-// CHECK:       %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:       %[[C1_3:.*]] = arith.constant 1 : index
-// CHECK:       %[[S5:.*]] = affine.apply #[[$MAP0]](%[[C1_3]])
-// CHECK:       %[[C1_4:.*]] = arith.constant 1 : index
-// CHECK:       %[[S6:.*]] = affine.apply #[[$MAP0]](%[[C1_4]])
-// CHECK:       %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
+// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_6:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C6_3]] step %[[C1_5]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_6]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C2_4]] step %[[C1_7]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG6]], %[[ARG8]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK-NEXT:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG8]])
+// CHECK-NEXT:           %[[S7:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:           %[[S8:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S5]], %[[S6]], 0] [2, %[[S7]], %[[S8]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
+// CHECK-NEXT:           %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x6x1x1x2x2xf32>) outs(%[[EXTRACTED_SLICE_8]] : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
+// CHECK-NEXT:           %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK-NEXT:           %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG8]])
+// CHECK-NEXT:           %[[S12:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:           %[[S13:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, %[[S10]], %[[S11]], 0] [2, %[[S12]], %[[S13]], 2] [1, 1, 1, 1] : tensor<2x?x?x2xf32> into tensor<2x8x8x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<2x8x8x2xf32>
 
 // -----
 
@@ -307,7 +474,7 @@ func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x3x5xf32>, %arg1: tensor<3
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop1:6 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -315,36 +482,65 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 3, 2)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK-LABEL:  func.func @tile_winograd_output(
-// CHECK-SAME:   %[[ARG0:.*]]: tensor<6x6x2x2x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
-// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C0_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C0_6:.*]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:    %[[C2_2:.*]] = arith.constant 2 : index
-// CHECK-DAG:    %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:    %[[C2_0:.*]] = arith.constant 2 : index
-// CHECK-DAG:    %[[C2_3:.*]] = arith.constant 2 : index
-// CHECK-DAG:    %[[C2_5:.*]] = arith.constant 2 : index
-// CHECK-DAG:    %[[C2_7:.*]] = arith.constant 2 : index
-// CHECK:    %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
-// CHECK:      %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
-// CHECK:        %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]]
-// CHECK:          %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]]
-// CHECK:            %[[C3_8:.*]] = arith.constant 3 : index
-// CHECK:            %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]])
-// CHECK:            %[[C5_9:.*]] = arith.constant 5 : index
-// CHECK:            %[[S6:.*]] = affine.min #[[$MAP1]](%[[ARG8]])
-// CHECK:            %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, %[[S5]], %[[S6]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x3x5xf32> to tensor<6x6x2x2x?x?xf32>
-// CHECK:            %[[S7:.*]] = affine.apply #[[$MAP2]](%[[ARG2]])
-// CHECK:            %[[S8:.*]] = affine.apply #[[$MAP2]](%[[ARG4]])
-// CHECK:            %[[C2_10:.*]] = arith.constant 2 : index
-// CHECK:            %[[S9:.*]] = affine.apply #[[$MAP2]](%[[C2_10]])
-// CHECK:            %[[C2_11:.*]] = arith.constant 2 : index
-// CHECK:            %[[S10:.*]] = affine.apply #[[$MAP2]](%[[C2_11]])
-// CHECK:            %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
+// CHECK: #[[$MAP3:.+]] = affine_map<() -> (8)>
+
+// CHECK-LABEL: func.func @tile_winograd_output(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<6x6x2x2x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<3x8x8x5xf32>
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C6_5:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_6:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2_8:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_9:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_10:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C2_11:.*]] = arith.constant 2 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C6]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<3x8x8x5xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C6_5]] step %[[C1_7]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<3x8x8x5xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C2]] step %[[C2_8]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<3x8x8x5xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C2_6]] step %[[C2_9]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<3x8x8x5xf32>) {
+// CHECK-NEXT:           %[[S5:.*]] = scf.for %[[ARG10:.*]] = %[[C0_3]] to %[[C3]] step %[[C2_10]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<3x8x8x5xf32>) {
+// CHECK-NEXT:             %[[S6:.*]] = scf.for %[[ARG12:.*]] = %[[C0_4]] to %[[C5]] step %[[C2_11]] iter_args(%[[ARG13:.*]] = %[[ARG11]]) -> (tensor<3x8x8x5xf32>) {
+// CHECK-NEXT:               %[[C3_12:.*]] = arith.constant 3 : index
+// CHECK-NEXT:               %[[S7:.*]] = affine.min #[[$MAP0]](%[[ARG10]])
+// CHECK-NEXT:               %[[C5_13:.*]] = arith.constant 5 : index
+// CHECK-NEXT:               %[[S8:.*]] = affine.min #[[$MAP1]](%[[ARG12]])
+// CHECK-NEXT:               %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG6]], %[[ARG8]], %[[ARG10]], %[[ARG12]]] [6, 6, 2, 2, %[[S7]], %[[S8]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x3x5xf32> to tensor<6x6x2x2x?x?xf32>
+// CHECK-NEXT:               %[[S9:.*]] = affine.apply #[[$MAP2]](%[[ARG6]])
+// CHECK-NEXT:               %[[S10:.*]] = affine.apply #[[$MAP2]](%[[ARG8]])
+// CHECK-NEXT:               %[[S11:.*]] = affine.apply #[[$MAP3]]()
+// CHECK-NEXT:               %[[S12:.*]] = affine.apply #[[$MAP3]]()
+// CHECK-NEXT:               %[[EXTRACTED_SLICE_14:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG10]], %[[S9]], %[[S10]], %[[ARG12]]] [%[[S7]], %[[S11]], %[[S12]], %[[S8]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
+// CHECK-NEXT:               %[[S13:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x6x2x2x?x?xf32>) outs(%[[EXTRACTED_SLICE_14]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT:               %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG6]])
+// CHECK-NEXT:               %[[S15:.*]] = affine.apply #[[$MAP2]](%[[ARG8]])
+// CHECK-NEXT:               %[[S16:.*]] = affine.apply #[[$MAP3]]()
+// CHECK-NEXT:               %[[S17:.*]] = affine.apply #[[$MAP3]]()
+// CHECK-NEXT:               %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG13]][%[[ARG10]], %[[S14]], %[[S15]], %[[ARG12]]] [%[[S7]], %[[S16]], %[[S17]], %[[S8]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> into tensor<3x8x8x5xf32>
+// CHECK-NEXT:               scf.yield %[[INSERTED_SLICE]] : tensor<3x8x8x5xf32>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[S6]] : tensor<3x8x8x5xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           scf.yield %[[S5]] : tensor<3x8x8x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<3x8x8x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<3x8x8x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<3x8x8x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<3x8x8x5xf32>
+// CHECK-NEXT: }
 
 // -----
 
@@ -356,36 +552,63 @@ func.func @tile_winograd_output(%arg0 : tensor<6x1x2x1x3x5xf32>, %arg1: tensor<3
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %loop1:6 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (4)>
+
 // CHECK-LABEL: func.func @tile_winograd_output(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<6x1x2x1x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<3x8x1x5xf32>
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C0_3:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C1_4:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C1_6:.*]] = arith.constant 1 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]]
-// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]]
-// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]]
-// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
-// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
-// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:           %[[C1_7:.*]] = arith.constant 1 : index
-// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP0]](%[[C1_7]])
-// CHECK:           %[[C1_8:.*]] = arith.constant 1 : index
-// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP0]](%[[C1_8]])
-// CHECK:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
-// CHECK:           %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>
+// CHECK-DAG:   %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_9:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_10:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C1_11:.*]] = arith.constant 1 : index
+// CHECK:       %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C6]] step %[[C1_6]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<3x8x1x5xf32>) {
+// CHECK-NEXT:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1]] step %[[C1_7]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<3x8x1x5xf32>) {
+// CHECK-NEXT:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_8]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<3x8x1x5xf32>) {
+// CHECK-NEXT:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_2]] to %[[C1_5]] step %[[C1_9]] iter_args(%[[ARG9:.*]] = %[[ARG7]]) -> (tensor<3x8x1x5xf32>) {
+// CHECK-NEXT:           %[[S5:.*]] = scf.for %[[ARG10:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_10]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<3x8x1x5xf32>) {
+// CHECK-NEXT:             %[[S6:.*]] = scf.for %[[ARG12:.*]] = %[[C0_4]] to %[[C5]] step %[[C1_11]] iter_args(%[[ARG13:.*]] = %[[ARG11]]) -> (tensor<3x8x1x5xf32>) {
+// CHECK-NEXT:               %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG6]], %[[ARG8]], %[[ARG10]], %[[ARG12]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
+// CHECK-NEXT:               %[[S7:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK-NEXT:               %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG8]])
+// CHECK-NEXT:               %[[S9:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:               %[[S10:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:               %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG10]], %[[S7]], 0, %[[ARG12]]] [1, %[[S9]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
+// CHECK-NEXT:               %[[S11:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>
+// CHECK-NEXT:               %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK-NEXT:               %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG8]])
+// CHECK-NEXT:               %[[S14:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:               %[[S15:.*]] = affine.apply #[[$MAP1]]()
+// CHECK-NEXT:               %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG13]][%[[ARG10]], %[[S12]], 0, %[[ARG12]]] [1, %[[S14]], 1, 1] [1, 1, 1, 1] : tensor<1x?x1x1xf32> into tensor<3x8x1x5xf32>
+// CHECK-NEXT:               scf.yield %[[INSERTED_SLICE]] : tensor<3x8x1x5xf32>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[S6]] : tensor<3x8x1x5xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           scf.yield %[[S5]] : tensor<3x8x1x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S4]] : tensor<3x8x1x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S3]] : tensor<3x8x1x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S2]] : tensor<3x8x1x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S1]] : tensor<3x8x1x5xf32>



More information about the Mlir-commits mailing list