[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)

Hsiangkai Wang llvmlistbot at llvm.org
Sat Aug 3 08:40:07 PDT 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96184

>From bad2ae08252a7d95e4655cf4fe080004a440ecf9 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:44:27 +0100
Subject: [PATCH 1/5] [mlir][linalg] Implement TilingInterface for winograd
 operations

To support conv2d input data of arbitrary size, implement
TilingInterface for winograd operations. Before converting winograd
operations into nested loops with matrix multiply, tile the input of
conv2d into the supported size first.

Add a transform operation structured.decompose_winograd_op to decompose
winograd operations. Before applying the transform op, use tile_using_for
to tile the input data into a supported size. The test case shows how to
tile and decompose winograd operations.
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  22 +-
 .../Linalg/TransformOps/LinalgTransformOps.td |  37 +++
 .../Dialect/Linalg/Transforms/Transforms.h    |  45 +++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 216 +++++++++++++++
 .../TransformOps/LinalgTransformOps.cpp       |  41 +++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  18 ++
 .../transform-tile-and-winograd-rewrite.mlir  | 260 ++++++++++++++++++
 7 files changed, 633 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a9007c8db3078..bc2df7082ef8e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,8 +154,8 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradFilterTransformOp :
-    Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+    [AllElementTypesMatch<["filter", "output"]>]> {
   let summary = "Winograd filter transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -193,8 +193,13 @@ def Linalg_WinogradFilterTransformOp :
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradInputTransformOp :
-    Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+    [AllElementTypesMatch<["input", "output"]>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd input transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -232,8 +237,13 @@ def Linalg_WinogradInputTransformOp :
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradOutputTransformOp :
-    Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+    [AllElementTypesMatch<["value", "output"]>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd output transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ecc86999006db..106f0d79d9792 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   }];
 }
 
+def DecomposeWinogradOp : Op<Transform_Dialect,
+    "structured.decompose_winograd_op",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Decompose winograd operations. It will convert filter, input and output
+    transform operations into a combination of scf, tensor, and linalg
+    equivalent operations. Before applying this transform operation, users
+    need to tile winograd transform operations into supported sizes.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle of the sequence that replaces the original
+    operations.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 477ef7bfafb18..89f907cfc7ade 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1316,6 +1316,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                       int64_t r);
 
+/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
+/// FHWC. The transformation matrix is 2-dimensional, so we need to extract
+/// H x W from FHWC first and generate 2 levels of loops to iterate over F and
+/// C. After the rewriting, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform. The data layout of the input is
+/// NHWC. The transformation matrix is 2-dimensional, so we need to extract
+/// H x W from NHWC first and generate 2 levels of loops to iterate over N and
+/// C. After the rewriting, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform. The data layout of the output is
+/// HWNF. The transformation matrix is 2-dimensional, so we need to extract
+/// H x W from HWNF first and generate 2 levels of loops to iterate over N and
+/// F. After the rewriting, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %f = lo_f to hi_f step 1
+///     %extracted = extract input<h x w> from result<h x w x n x f>
+///     %ret = linalg.matmul AT, %extracted
+///     %ret = linalg.matmul %ret, A
+///     %inserted = insert %ret into ret<n x h x w x f>
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a101552e419bc..47800dfc1d1a2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2885,6 +2885,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
 // WinogradInputTransformOp
 //===----------------------------------------------------------------------===//
 
+/// Returns `opFoldResult` as a Value, materializing an integer attribute as
+/// an arith.constant at `loc`. File-local helper, hence `static`.
+static Value getValueFromOpFoldResult(OpFoldResult opFoldResult,
+                                      OpBuilder &builder, Location loc) {
+  if (auto attr = opFoldResult.dyn_cast<Attribute>())
+    return builder.create<arith::ConstantOp>(loc, cast<IntegerAttr>(attr));
+  return opFoldResult.get<Value>();
+}
+
 LogicalResult WinogradInputTransformOp::verify() {
   auto inputType = cast<ShapedType>(getInput().getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2922,6 +2931,113 @@ LogicalResult WinogradInputTransformOp::verify() {
   return success();
 }
 
+/// Returns the iteration domain: one range per dimension of the output
+/// operand (dims 0-5), each starting at 0 with unit stride and spanning that
+/// dimension's size as reported by getDimValue.
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IntegerAttr zero = builder.getIntegerAttr(builder.getIndexType(), 0);
+  IntegerAttr one = builder.getIntegerAttr(builder.getIndexType(), 1);
+  Value output = getOutput();
+  SmallVector<Range> domain;
+  domain.reserve(6);
+  for (unsigned dim = 0; dim < 6; ++dim)
+    domain.push_back(Range{zero, getDimValue(builder, loc, output, dim), one});
+  return domain;
+}
+
+/// Reports the iterator type of each of the six loops; all of them are
+/// parallel.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
+}
+
+/// Computes where the tile described by `offsets`/`sizes` lands in the result.
+///
+/// Dims 2 and 3 take the tile's own offsets with an extent of 1; dims 0, 1, 4
+/// and 5 start at 0 and keep the requested sizes. `resultNumber` is not read.
+/// The six computed entries are appended to `resultOffsets` and to
+/// `resultSizes` (existing contents are kept, not cleared).
+/// Always returns success().
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  IntegerAttr zero = builder.getI64IntegerAttr(0);
+  IntegerAttr one = builder.getI64IntegerAttr(1);
+
+  // Offsets: only the two tile dimensions move with the iteration space.
+  resultOffsets.append({zero, zero, offsets[2], offsets[3], zero, zero});
+
+  // Sizes: a single tile along dims 2 and 3, the requested size elsewhere.
+  resultSizes.append({sizes[0], sizes[1], one, one, sizes[4], sizes[5]});
+
+  return success();
+}
+
+/// Builds the tiled clone of this op for the tile given by `offsets`/`sizes`.
+///
+/// The input slice starts at (offsets[2] * m, offsets[3] * m) in dims 1 and 2
+/// and spans an alphaH x alphaW window (alpha = m + r - 1, or 1 when the input
+/// extent along that dim is 1); dims 0 and 3 use sizes[4] and sizes[5]. The
+/// output slice is the one computed by getResultTilePosition.
+/// NOTE(review): sizes[2] and sizes[3] are never read here, so the input
+/// window always covers exactly one tile -- confirm callers tile dims 2 and 3
+/// by 1.
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  Location loc = getLoc();
+  IntegerAttr one = builder.getI64IntegerAttr(1);
+  IntegerAttr zero = builder.getI64IntegerAttr(0);
+
+  Value input = getInput();
+  ArrayRef<int64_t> inputShape = cast<ShapedType>(input.getType()).getShape();
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  // A unit extent in H (or W) keeps a window of 1 along that dimension.
+  auto alphaH = builder.getI64IntegerAttr(inputShape[1] != 1 ? alpha : 1);
+  auto alphaW = builder.getI64IntegerAttr(inputShape[2] != 1 ? alpha : 1);
+
+  // Tile indices are scaled by m to obtain element offsets into the input.
+  AffineMap scaleByM = AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m},
+                                      builder.getContext());
+  Value offsetH = builder.create<affine::AffineApplyOp>(
+      loc, scaleByM, getValueFromOpFoldResult(offsets[2], builder, loc));
+  Value offsetW = builder.create<affine::AffineApplyOp>(
+      loc, scaleByM, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+  // Operand 0: the input window feeding this tile.
+  SmallVector<OpFoldResult> inOffsets = {zero, offsetH, offsetW, zero};
+  SmallVector<OpFoldResult> inSizes = {sizes[4], alphaH, alphaW, sizes[5]};
+  SmallVector<OpFoldResult> inStrides(4, one);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, input, inOffsets, inSizes, inStrides));
+
+  // Operand 1: the output slice this tile writes. The op has a single
+  // result, so its tile position is queried as result number 0.
+  SmallVector<OpFoldResult> outOffsets, outSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   outOffsets, outSizes)))
+    return failure();
+  SmallVector<OpFoldResult> outStrides(6, one);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), outOffsets, outSizes, outStrides));
+
+  // The clone's only result has the type of the output slice.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // WinogradOutputTransformOp
 //===----------------------------------------------------------------------===//
@@ -2964,6 +3080,106 @@ LogicalResult WinogradOutputTransformOp::verify() {
   return success();
 }
 
+/// Returns the iteration domain: one range per dimension of the `value`
+/// operand (dims 0-5), each starting at 0 with unit stride and spanning that
+/// dimension's size as reported by getDimValue.
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IntegerAttr zero = builder.getIntegerAttr(builder.getIndexType(), 0);
+  IntegerAttr one = builder.getIntegerAttr(builder.getIndexType(), 1);
+  Value value = getValue();
+  SmallVector<Range> domain;
+  domain.reserve(6);
+  for (unsigned dim = 0; dim < 6; ++dim)
+    domain.push_back(Range{zero, getDimValue(builder, loc, value, dim), one});
+  return domain;
+}
+
+/// Reports the iterator type of each of the six loops; all of them are
+/// parallel.
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
+}
+
+/// Computes where the tile described by `offsets`/`sizes` lands in the 4-D
+/// output.
+///
+/// Dims 1 and 2 start at offsets[2] * m and offsets[3] * m (materialized as
+/// affine.apply ops) and extend over m elements, or 1 when the output extent
+/// along that dim is 1. Dims 0 and 3 start at 0 with sizes[4] and sizes[5].
+/// `resultNumber` is not read. Results are appended to `resultOffsets` and
+/// `resultSizes`; always returns success().
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  Location loc = getLoc();
+  IntegerAttr zero = builder.getI64IntegerAttr(0);
+
+  Value output = getOutput();
+  ArrayRef<int64_t> outputShape = cast<ShapedType>(output.getType()).getShape();
+  int64_t m = getM();
+  auto tileH = builder.getI64IntegerAttr(outputShape[1] != 1 ? m : 1);
+  auto tileW = builder.getI64IntegerAttr(outputShape[2] != 1 ? m : 1);
+
+  // Tile indices are scaled by m to obtain element offsets into the output.
+  AffineMap scaleByM = AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m},
+                                      builder.getContext());
+  Value offsetH = builder.create<affine::AffineApplyOp>(
+      loc, scaleByM, getValueFromOpFoldResult(offsets[2], builder, loc));
+  Value offsetW = builder.create<affine::AffineApplyOp>(
+      loc, scaleByM, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+  resultOffsets.append({zero, offsetH, offsetW, zero});
+  resultSizes.append({sizes[4], tileH, tileW, sizes[5]});
+  return success();
+}
+
+/// Builds the tiled clone of this op for the tile given by `offsets`/`sizes`.
+///
+/// The `value` slice keeps dims 0, 1, 4 and 5 from offset 0 with the requested
+/// sizes and takes a single tile at (offsets[2], offsets[3]) along dims 2 and
+/// 3. The output slice is the one computed by getResultTilePosition.
+/// NOTE(review): sizes[2] and sizes[3] are never read here, so exactly one
+/// tile per call is assumed -- confirm callers tile dims 2 and 3 by 1.
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  Location loc = getLoc();
+  IntegerAttr one = builder.getI64IntegerAttr(1);
+  IntegerAttr zero = builder.getI64IntegerAttr(0);
+
+  // Operand 0: the slice of `value` holding this tile.
+  SmallVector<OpFoldResult> valueOffsets = {zero, zero, offsets[2],
+                                            offsets[3], zero, zero};
+  SmallVector<OpFoldResult> valueSizes = {sizes[0], sizes[1], one,
+                                          one, sizes[4], sizes[5]};
+  SmallVector<OpFoldResult> valueStrides(6, one);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getValue(), valueOffsets, valueSizes, valueStrides));
+
+  // Operand 1: the output slice this tile writes. The op has a single
+  // result, so its tile position is queried as result number 0.
+  SmallVector<OpFoldResult> outOffsets, outSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   outOffsets, outSizes)))
+    return failure();
+  SmallVector<OpFoldResult> outStrides(4, one);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), outOffsets, outSizes, outStrides));
+
+  // The clone's only result has the type of the output slice.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9baf358a95503..325dc04a828f0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3853,6 +3853,47 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Decomposes `target` if it is a Winograd filter, input or output transform
+/// op and records the op returned by the decomposition in `results`.
+DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  FailureOr<Operation *> maybeTransformed = failure();
+  bool supported =
+      TypeSwitch<Operation *, bool>(target)
+          .Case([&](linalg::WinogradFilterTransformOp op) {
+            maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
+            return true;
+          })
+          .Case([&](linalg::WinogradInputTransformOp op) {
+            maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
+            return true;
+          })
+          .Case([&](linalg::WinogradOutputTransformOp op) {
+            maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
+            return true;
+          })
+          .Default([](Operation *) { return false; });
+
+  if (!supported) {
+    auto diag = emitSilenceableError()
+        << "this operation is not supported to decompose into other operations";
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  // `supported` is known to be true here; only the rewrite can have failed.
+  if (failed(maybeTransformed)) {
+    auto diag = emitSilenceableError() << "decompose Winograd operations failed";
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index c6c770e2781ff..c01a5f4b055f1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -1169,6 +1169,24 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
   return winogradConv2DHelper(rewriter, op, m, r);
 }
 
+/// Public entry point; forwards to decomposeWinogradFilterTransformHelper.
+FailureOr<Operation *> decomposeWinogradFilterTransformOp(
+    RewriterBase &rewriter, linalg::WinogradFilterTransformOp op) {
+  return decomposeWinogradFilterTransformHelper(rewriter, op);
+}
+
+/// Public entry point; forwards to decomposeWinogradInputTransformHelper.
+FailureOr<Operation *> decomposeWinogradInputTransformOp(
+    RewriterBase &rewriter, linalg::WinogradInputTransformOp op) {
+  return decomposeWinogradInputTransformHelper(rewriter, op);
+}
+
+/// Public entry point; forwards to decomposeWinogradOutputTransformHelper.
+FailureOr<Operation *> decomposeWinogradOutputTransformOp(
+    RewriterBase &rewriter, linalg::WinogradOutputTransformOp op) {
+  return decomposeWinogradOutputTransformHelper(rewriter, op);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 0000000000000..29833324bb21b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,260 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %2 = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+  %4 = tensor.empty() : tensor<36x8x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %6 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK:        scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:     scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK:            linalg.matmul
+// CHECK:            linalg.matmul
+// CHECK:        %[[S5:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK:        %[[S6:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK:        scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT:     scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK:            %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:            %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:            tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK:            %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:            %[[S15:.*]] = scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK-NEXT:         scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG8]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:                linalg.matmul
+// CHECK:                linalg.matmul
+// CHECK:          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
+// CHECK:       %[[S9:.*]] = linalg.batch_matmul
+// CHECK:       %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK:   %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK:   scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:     scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK:       tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK:       %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:       %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:         scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:           linalg.matmul
+// CHECK:           linalg.matmul
+// CHECK:       %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:       %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+  ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+    tensor.yield %cst : f32
+  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+  %2 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+  %4 = tensor.empty() : tensor<36x18x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+  %padded_1 = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+  ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+    tensor.yield %cst : f32
+  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+  return %extracted_slice : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK:   scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK:       %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK:       tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK:       linalg.matmul
+// CHECK:       %[[S13:.*]] = linalg.matmul
+// CHECK:       %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK:       %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
+// CHECK:   %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK:   ^bb0({{.*}}):
+// CHECK:     tensor.yield %[[CST_6]] : f32
+// CHECK:   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK:   %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK:   scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK:       %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:       tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
+// CHECK:       %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:       %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:         scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK:           linalg.matmul
+// CHECK:           %[[S17:.*]] = linalg.matmul
+// CHECK:           %[[S18:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+// CHECK:           %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
+// CHECK:           %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
+// CHECK:           scf.yield %[[INSERTED_SLICE_14]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
+// CHECK:   %[[S6:.*]] = linalg.batch_matmul
+// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK:   %[[PADDED_8:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK:   ^bb0({{.*}}):
+// CHECK:     tensor.yield %[[CST_6]] : f32
+// CHECK:   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK:   %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK:   %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK:       tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK:       %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:       %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
+// CHECK:       %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:         scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK:           linalg.matmul
+// CHECK:           linalg.matmul
+// CHECK:           %[[S22:.*]] = linalg.mul
+// CHECK:           %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK:           %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK:           %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK:           scf.yield %[[INSERTED_SLICE_14]] : tensor<2x4x4x2xf32>
+// CHECK:       %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
+// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
+// CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK: }
+
+// -----
+
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = tensor.empty() : tensor<6x1x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+  %4 = tensor.empty() : tensor<6x2x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %6 : tensor<2x4x1x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[C5:.*]] = arith.constant 5 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
+// CHECK:      %[[S9:.*]] = linalg.matmul
+// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
+// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
+// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
+// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
+// CHECK:      %[[S9:.*]] = linalg.matmul
+// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> into tensor<6x1x1x1x2x5xf32>
+// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x1x1x2x5xf32>
+// CHECK:  %[[S5:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1x1x1x1x1xf32>
+// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> to tensor<6x1xf32>
+// CHECK:      linalg.matmul
+// CHECK:      %[[S12:.*]] = linalg.mul
+// CHECK:      %[[S13:.*]] = tensor.empty() : tensor<1x4x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[S13]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
+// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>

>From 0c121bbef5f970a6ebc632c4d791ba535f2c29f7 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 17 Jul 2024 18:26:14 +0100
Subject: [PATCH 2/5] Address ftynse's comments

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 122 +++----
 .../transform-tile-and-winograd-rewrite.mlir  | 332 ++++++++++--------
 2 files changed, 226 insertions(+), 228 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 47800dfc1d1a2..4905e3f8283f8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2885,15 +2885,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
 // WinogradInputTransformOp
 //===----------------------------------------------------------------------===//
 
-Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
-                               Location loc) {
-  if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
-    auto intAttr = cast<IntegerAttr>(attr);
-    return builder.create<arith::ConstantOp>(loc, intAttr);
-  }
-  return opFoldResult.get<Value>();
-}
-
 LogicalResult WinogradInputTransformOp::verify() {
   auto inputType = cast<ShapedType>(getInput().getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2934,9 +2925,9 @@ LogicalResult WinogradInputTransformOp::verify() {
 SmallVector<Range>
 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
-  auto indexType = builder.getIndexType();
-  auto zeroAttr = builder.getIntegerAttr(indexType, 0);
-  auto oneAttr = builder.getIntegerAttr(indexType, 1);
+  IndexType indexType = builder.getIndexType();
+  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
   Value output = getOutput();
   SmallVector<Range> loopBounds(6);
   for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2958,21 +2949,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  auto zeroAttr = builder.getI64IntegerAttr(0);
-  auto oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
 
-  resultOffsets.push_back(zeroAttr);
-  resultOffsets.push_back(zeroAttr);
-  resultOffsets.push_back(offsets[2]);
-  resultOffsets.push_back(offsets[3]);
-  resultOffsets.push_back(zeroAttr);
-  resultOffsets.push_back(zeroAttr);
-  resultSizes.push_back(sizes[0]);
-  resultSizes.push_back(sizes[1]);
-  resultSizes.push_back(oneAttr);
-  resultSizes.push_back(oneAttr);
-  resultSizes.push_back(sizes[4]);
-  resultSizes.push_back(sizes[5]);
+  resultOffsets.append(
+      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+  resultSizes.append(
+      {sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
 
   return success();
 }
@@ -2981,11 +2964,11 @@ FailureOr<TilingResult>
 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
-  auto oneAttr = builder.getI64IntegerAttr(1);
-  auto zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   Value input = getInput();
   auto inputType = cast<ShapedType>(input.getType());
-  auto inputShape = inputType.getShape();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[1];
   int64_t inputW = inputShape[2];
   int64_t m = getM();
@@ -2993,29 +2976,25 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   int64_t alpha = m + r - 1;
   int64_t alphaH = inputH != 1 ? alpha : 1;
   int64_t alphaW = inputW != 1 ? alpha : 1;
-  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
-  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
   Location loc = getLoc();
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
-  auto context = builder.getContext();
+  MLIRContext *context = builder.getContext();
   auto affineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
   Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
   Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
-
-  sliceOffsets.push_back(zeroAttr);
-  sliceOffsets.push_back(mappedOffset1);
-  sliceOffsets.push_back(mappedOffset2);
-  sliceOffsets.push_back(zeroAttr);
-  sliceSizes.push_back(sizes[4]);
-  sliceSizes.push_back(alphaHAttr);
-  sliceSizes.push_back(alphaWAttr);
-  sliceSizes.push_back(sizes[5]);
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+  sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+  sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
   SmallVector<OpFoldResult> inputStrides(4, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
@@ -3030,7 +3009,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
 
-  SmallVector<Type, 4> resultTypes;
+  SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
@@ -3083,9 +3062,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
 SmallVector<Range>
 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
-  auto indexType = builder.getIndexType();
-  auto zeroAttr = builder.getIntegerAttr(indexType, 0);
-  auto oneAttr = builder.getIntegerAttr(indexType, 1);
+  IndexType indexType = builder.getIndexType();
+  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
   Value value = getValue();
   SmallVector<Range> loopBounds(6);
   for (unsigned dim = 0; dim < 6; ++dim) {
@@ -3107,57 +3086,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  auto zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   Value output = getOutput();
   auto outputType = cast<ShapedType>(output.getType());
-  auto outputShape = outputType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
   int64_t outputH = outputShape[1];
   int64_t outputW = outputShape[2];
   int64_t m = getM();
-  auto heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
-  auto widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
+  IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
+  IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
 
   Location loc = getLoc();
-  auto context = builder.getContext();
+  MLIRContext *context = builder.getContext();
   auto affineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
   Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
   Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
-
-  resultOffsets.push_back(zeroAttr);
-  resultOffsets.push_back(mappedOffset1);
-  resultOffsets.push_back(mappedOffset2);
-  resultOffsets.push_back(zeroAttr);
-  resultSizes.push_back(sizes[4]);
-  resultSizes.push_back(heightM);
-  resultSizes.push_back(widthM);
-  resultSizes.push_back(sizes[5]);
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+  resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+  resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
   return success();
 }
 
 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes) {
-  auto oneAttr = builder.getI64IntegerAttr(1);
-  auto zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
   Location loc = getLoc();
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
-  sliceOffsets.push_back(zeroAttr);
-  sliceOffsets.push_back(zeroAttr);
-  sliceOffsets.push_back(offsets[2]);
-  sliceOffsets.push_back(offsets[3]);
-  sliceOffsets.push_back(zeroAttr);
-  sliceOffsets.push_back(zeroAttr);
-  sliceSizes.push_back(sizes[0]);
-  sliceSizes.push_back(sizes[1]);
-  sliceSizes.push_back(oneAttr);
-  sliceSizes.push_back(oneAttr);
-  sliceSizes.push_back(sizes[4]);
-  sliceSizes.push_back(sizes[5]);
+  sliceOffsets.append(
+      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+  sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
   SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
@@ -3172,7 +3138,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getOutput(), sliceOffsets, sliceSizes, strides));
 
-  SmallVector<Type, 4> resultTypes;
+  SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 29833324bb21b..6bb3fb1423edc 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -31,52 +31,77 @@ module attributes {transform.with_named_sequence} {
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK:        scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT:     scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK:            linalg.matmul
-// CHECK:            linalg.matmul
-// CHECK:        %[[S5:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:        %[[S6:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:        scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<6x6x2x2x2x5xf32>) {
-// CHECK-NEXT:     scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
-// CHECK:            %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:            %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:            tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
-// CHECK:            %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK:            %[[S15:.*]] = scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK-NEXt:         scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG8]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK:                linalg.matmul
-// CHECK:                linalg.matmul
-// CHECK:          %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
-// CHECK:       %[[S9:.*]] = linalg.batch_matmul
-// CHECK:       %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK:   %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK:   scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
-// CHECK-NEXT:     scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
-// CHECK:       tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK:       %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:       %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
-// CHECK:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:         scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK:           linalg.matmul
-// CHECK:           linalg.matmul
-// CHECK:       %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:       %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[C5:.*]] = arith.constant 5 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[S0:.*]] = tensor.empty()
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK:      %[[S11:.*]] = linalg.matmul
+// CHECK:      %[[S13:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK:          %[[S15:.*]] = linalg.matmul
+// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
+// CHECK:        scf.yield %[[S13]]
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[S6:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
+// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:          %[[S19:.*]] = linalg.matmul
+// CHECK:          %[[S20:.*]] = tensor.empty()
+// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:            linalg.yield %[[IN]] : f32
+// CHECK:          } -> tensor<4x4xf32>
+// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
+// CHECK:        scf.yield %[[S15]]
+// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
 
 // -----
 
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<6x6x5x2xf32>
   %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
@@ -91,7 +116,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
   %4 = tensor.empty() : tensor<36x18x2xf32>
   %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
   %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
-  %padded_1 = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+  %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
   ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
     tensor.yield %cst : f32
   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
@@ -117,80 +142,82 @@ module attributes {transform.with_named_sequence} {
 }
 
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK:   scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK:       %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
-// CHECK:       tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
-// CHECK:       linalg.matmul
-// CHECK:       %[[S13:.*]] = linalg.matmul
-// CHECK:       %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
-// CHECK:       %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
-// CHECK:   %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
-// CHECK:   ^bb0({{.*}}):
-// CHECK:     tensor.yield %[[CST_6]] : f32
-// CHECK:   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
-// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:   %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:   scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK:       %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:       %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:       tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
-// CHECK:       %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK:       %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK:         scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK:           linalg.matmul
-// CHECK:           %[[S17:.*]] = linalg.matmul
-// CHECK:           %[[S18:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK:           %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
-// CHECK:           %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
-// CHECK:           scf.yield %[[INSERTED_SLICE_14]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
-// CHECK:   %[[S6:.*]] = linalg.batch_matmul
-// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
-// CHECK:   %[[PADDED_8:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
-// CHECK:   ^bb0({{.*}}):
-// CHECK:     tensor.yield %[[CST_6]] : f32
-// CHECK:   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK:   %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK:   %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK:     scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK:       tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK:       %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:       %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:       %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
-// CHECK:       %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK:         scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK:           linalg.matmul
-// CHECK:           linalg.matmul
-// CHECK:           %[[S22:.*]] = linalg.mul
-// CHECK:           %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
-// CHECK:           %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
-// CHECK:           %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
-// CHECK:           scf.yield %[[INSERTED_SLICE_14]] : tensor<2x4x4x2xf32>
-// CHECK:       %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:       %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
-// CHECK:       scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
-// CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK: }
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK:  %[[C3:.*]] = arith.constant 3 : index
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[C5:.*]] = arith.constant 5 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[S0:.*]] = tensor.empty()
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK:      %[[S11:.*]] = linalg.matmul
+// CHECK:      %[[S13:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK:          %[[S15:.*]] = linalg.matmul
+// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:        scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[S6:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// CHECK:  %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:          %[[S19:.*]] = linalg.matmul
+// CHECK:          %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:            linalg.yield %[[IN]] : f32
+// CHECK:          } -> tensor<4x4xf32>
+// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_12]]
+// CHECK:        scf.yield %[[S15]] : tensor<2x4x4x2xf32>
+// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      scf.yield %[[INSERTED_SLICE]]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK:  return %[[EXTRACTED_SLICE]]
 
 // -----
 
-func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
   %0 = tensor.empty() : tensor<6x1x5x2xf32>
   %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
   %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
@@ -200,7 +227,7 @@ func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
   %4 = tensor.empty() : tensor<6x2x2xf32>
   %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
   %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
   return %6 : tensor<2x4x1x2xf32>
 }
 
@@ -220,41 +247,46 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d_mx1_rx1
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK:  %[[C1:.*]] = arith.constant 1 : index
-// CHECK:  %[[C5:.*]] = arith.constant 5 : index
-// CHECK:  %[[C2:.*]] = arith.constant 2 : index
-// CHECK:  %[[C0:.*]] = arith.constant 0 : index
-// CHECK:  %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
-// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
-// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
-// CHECK:      %[[S9:.*]] = linalg.matmul
-// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
-// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
-// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
-// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x1x1x2x5xf32>) {
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
-// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
-// CHECK:      %[[S9:.*]] = linalg.matmul
-// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> into tensor<6x1x1x1x2x5xf32>
-// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x1x1x2x5xf32>
-// CHECK:  %[[S5:.*]] = linalg.batch_matmul
-// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK:  scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x4x1x2xf32>) {
-// CHECK:    scf.for %[[ARG6:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x4x1x2xf32>) {
-// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1x1x1x1x1xf32>
-// CHECK:      tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> to tensor<6x1xf32>
-// CHECK:      linalg.matmul
-// CHECK:      %[[S12:.*]] = linalg.mul
-// CHECK:      %[[S13:.*]] = tensor.empty() : tensor<1x4x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[S13]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
-// CHECK:      %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
-// CHECK:      scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S9:.*]] = linalg.matmul
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S9:.*]] = linalg.matmul
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:   %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:   %[[S5:.*]] = linalg.batch_matmul
+// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
+// CHECK:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       %[[S9:.*]] = linalg.matmul
+// CHECK:       %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK:       %[[S11:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) {
+// CHECK:       ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:         linalg.yield %[[IN]] : f32
+// CHECK:       } -> tensor<4x1xf32>
+// CHECK:       %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   return %[[S6]]

>From b4e5c7c97b770391c81839261d47956925e23daf Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 18 Jul 2024 06:52:17 +0100
Subject: [PATCH 3/5] Address more comments

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4905e3f8283f8..da55360c82eb0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2999,15 +2999,14 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
 
-  sliceOffsets.clear();
-  sliceSizes.clear();
-  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
-                                   sliceSizes)))
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+                                   resultSizes)))
     return failure();
 
   SmallVector<OpFoldResult> outputStrides(6, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
@@ -3128,15 +3127,14 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
 
-  sliceOffsets.clear();
-  sliceSizes.clear();
-  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
-                                   sliceSizes)))
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+                                   resultSizes)))
     return failure();
 
   SmallVector<OpFoldResult> strides(4, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+      loc, getOutput(), resultOffsets, resultSizes, strides));
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());

>From e972232669e6de1638a86b62808cff7290ff3e70 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 18 Jul 2024 07:24:42 +0100
Subject: [PATCH 4/5] Update comments

---
 .../Dialect/Linalg/Transforms/Transforms.h    | 52 ++++++++++++-------
 1 file changed, 32 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 89f907cfc7ade..861e14d22d962 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1318,8 +1318,8 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
 
 /// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
 /// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
-/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
-/// C. After the rewriting, we get
+/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
+/// the rewriting, we get
 ///
 /// scf.for %f = lo_f to hi_f step 1
 ///   scf.for %c = lo_c to hi_c step 1
@@ -1333,30 +1333,42 @@ decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
 
 /// Rewrite linalg.winograd_input_transform. The data layout of the input is
 /// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
-/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
-/// C. After the rewriting, we get
-///
-/// scf.for %n = lo_n to hi_n step 1
-///   scf.for %c = lo_c to hi_c step 1
-///     %extracted = extract input<h x w> from input<n x h x w x c>
-///     %ret = linalg.matmul BT, %extracted
-///     %ret = linalg.matmul %ret, B
-///     %inserted = insert %ret into input<h x w x n x c>
+/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
+/// and tileW. After the rewriting, we get
+///
+/// scf.for %h = 0 to tileH step 1
+///   scf.for %w = 0 to tileW step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %c = 0 to C step 1
+///         %extracted = extract %extracted<alphaH x alphaW> from
+///                              %input<N x H x W x C>
+///                              at [%n, (%h x m), (%w x m), %c]
+///         %ret = linalg.matmul BT, %extracted
+///         %ret = linalg.matmul %ret, B
+///         %inserted = insert %ret<alphaH x alphaW> into
+///                            %output<alphaH x alphaW x tileH x tileW x N x C>
+///                            at [0, 0, %h, %w, %n, %c]
 FailureOr<Operation *>
 decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradInputTransformOp op);
 
 /// Rewrite linalg.winograd_output_transform. The data layout of the output is
 /// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
-/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
-/// F. After the transformation, we get
-///
-/// scf.for %n = lo_n to hi_n step 1
-///   scf.for %f = lo_f to hi_f step 1
-///     %extracted = extract input<h x w> from result<h x w x n x f>
-///     %ret = linalg.matmul AT, %extracted
-///     %ret = linalg.matmul %ret, A
-///     %inserted = insert %ret into ret<n x h x w x f>
+/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
+/// and tileW. After the transformation, we get
+///
+/// scf.for %h = 0 to tileH step 1
+///   scf.for %w = 0 to tileW step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %f = 0 to F step 1
+///         %extracted = extract %extracted<alphaH x alphaW> from
+///                              %input<alphaH x alphaW x tileH x tileW x N x F>
+///                              at [0, 0, %h, %w, %n, %f]
+///         %ret = linalg.matmul AT, %extracted
+///         %ret = linalg.matmul %ret, A
+///         %inserted = insert %ret<alphaH x alphaW> into
+///                            %output<N x H x W x F>
+///                            at [%n, (%h x m), (%w x m), %f]
 FailureOr<Operation *>
 decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                    linalg::WinogradOutputTransformOp op);

>From 73696eec36aa8cebc80cf093b995a7fb34064951 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 1 Aug 2024 08:44:10 +0100
Subject: [PATCH 5/5] address Max191's comments

---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  49 ++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 279 ++++++++++---
 .../Linalg/Transforms/WinogradConv2D.cpp      |   7 -
 .../transform-tile-and-winograd-rewrite.mlir  |  12 +-
 .../Linalg/transform-tile-winograd.mlir       | 392 ++++++++++++++++++
 5 files changed, 664 insertions(+), 75 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-winograd.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index bc2df7082ef8e..aaf243e015726 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -155,7 +155,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
 }
 
 def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
-    [AllElementTypesMatch<["filter", "output"]>]> {
+    [AllElementTypesMatch<["filter", "output"]>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd filter transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,6 +195,20 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
   }];
+  let extraClassDeclaration = [{
+    ShapedType getFilterOperandType() {
+      return cast<ShapedType>(getFilter().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getFilterOperandRank() {
+      return getFilterOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -234,6 +253,20 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
   }];
+  let extraClassDeclaration = [{
+    ShapedType getInputOperandType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getInputOperandRank() {
+      return getInputOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -278,6 +311,20 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
   }];
+  let extraClassDeclaration = [{
+    ShapedType getValueOperandType() {
+      return cast<ShapedType>(getValue().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getValueOperandRank() {
+      return getValueOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+  }];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index da55360c82eb0..ee633a715cc9e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2881,6 +2881,105 @@ LogicalResult WinogradFilterTransformOp::verify() {
   return success();
 }
 
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  Value filter = getFilter();
+  SmallVector<Range> loopBounds;
+  Range fRange;
+  fRange.offset = zeroAttr;
+  // 0 is the index of F in the filter.
+  fRange.size = getDimValue(builder, loc, filter, 0);
+  fRange.stride = oneAttr;
+  loopBounds.emplace_back(fRange);
+  Range cRange;
+  cRange.offset = zeroAttr;
+  // 3 is the index of C in the filter.
+  cRange.size = getDimValue(builder, loc, filter, 3);
+  cRange.stride = oneAttr;
+  loopBounds.emplace_back(cRange);
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  int64_t filterRank = getFilterOperandRank();
+  SmallVector<utils::IteratorType> iteratorTypes(filterRank - 2,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  ShapedType filterType = getFilterOperandType();
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = filterH != 1 ? alpha : 1;
+  int64_t alphaW = filterW != 1 ? alpha : 1;
+  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  resultOffsets.append({zeroAttr, zeroAttr, offsets[1], offsets[0]});
+  resultSizes.append({alphaHAttr, alphaWAttr, sizes[1], sizes[0]});
+
+  return success();
+}
+
+/// Implement tiling for winograd_filter_transform.
+/// The input of winograd_filter_transform is (F, KH, KW, C).
+/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
+/// Users can specify the tile sizes of F and C.
+/// `offsets` are the values for the offsets of F, C for one tile.
+/// `sizes` are the values for the sizes of F, C for one tile.
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  ShapedType filterType = getFilterOperandType();
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
+  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  sliceOffsets.append({offsets[0], zeroAttr, zeroAttr, offsets[1]});
+  sliceSizes.append({sizes[0], filterHAttr, filterWAttr, sizes[1]});
+  int64_t filterRank = getFilterOperandRank();
+  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
+  Location loc = getLoc();
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
+
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+                                   resultSizes)))
+    return failure();
+
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+
+  SmallVector<Type> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // WinogradInputTransformOp
 //===----------------------------------------------------------------------===//
@@ -2925,22 +3024,23 @@ LogicalResult WinogradInputTransformOp::verify() {
 SmallVector<Range>
 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
-  IndexType indexType = builder.getIndexType();
-  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
-  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
   Value output = getOutput();
-  SmallVector<Range> loopBounds(6);
-  for (unsigned dim = 0; dim < 6; ++dim) {
-    loopBounds[dim].offset = zeroAttr;
-    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
-    loopBounds[dim].stride = oneAttr;
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<Range> loopBounds(outputRank - 2);
+  for (unsigned dim = 2; dim < outputRank; ++dim) {
+    loopBounds[dim - 2].offset = zeroAttr;
+    loopBounds[dim - 2].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim - 2].stride = oneAttr;
   }
   return loopBounds;
 }
 
 SmallVector<utils::IteratorType>
 WinogradInputTransformOp::getLoopIteratorTypes() {
-  SmallVector<utils::IteratorType> iteratorTypes(6,
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<utils::IteratorType> iteratorTypes(outputRank - 2,
                                                  utils::IteratorType::parallel);
   return iteratorTypes;
 }
@@ -2950,52 +3050,79 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  ShapedType inputType = getInputOperandType();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
   resultOffsets.append(
-      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+      {zeroAttr, zeroAttr, offsets[0], offsets[1], offsets[2], offsets[3]});
   resultSizes.append(
-      {sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
+      {alphaHAttr, alphaWAttr, sizes[0], sizes[1], sizes[2], sizes[3]});
 
   return success();
 }
 
+/// Implement tiling for winograd_input_transform.
+/// The input of winograd_input_transform is (N, H, W, C). The output is
+/// (alphaH, alphaW, tileH, tileW, N, C). Users can specify the tile sizes of
+/// tileH, tileW, N, and C. `offsets` are the values for the offsets of tileH,
+/// tileW, N, C for one tile. `sizes` are the values for the sizes of tileH,
+/// tileW, N, C for one tile.
 FailureOr<TilingResult>
 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  Value input = getInput();
-  auto inputType = cast<ShapedType>(input.getType());
+  ShapedType inputType = getInputOperandType();
   ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[1];
   int64_t inputW = inputShape[2];
   int64_t m = getM();
   int64_t r = getR();
-  int64_t alpha = m + r - 1;
-  int64_t alphaH = inputH != 1 ? alpha : 1;
-  int64_t alphaW = inputW != 1 ? alpha : 1;
-  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
-  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
   Location loc = getLoc();
-  SmallVector<Value> tiledOperands;
-  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
-
   MLIRContext *context = builder.getContext();
-  auto affineMap =
+  auto offsetAffineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
-  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
-      loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+  Value mappedOffsetH = builder.create<affine::AffineApplyOp>(
+      loc, offsetAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[0]));
+  Value mappedOffsetW = builder.create<affine::AffineApplyOp>(
+      loc, offsetAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[1]));
+  auto sizeAffineMap = AffineMap::get(
+      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
+  Value mappedSizeH = builder.create<affine::AffineApplyOp>(
+      loc, sizeAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, sizes[0]));
+  Value mappedSizeW = builder.create<affine::AffineApplyOp>(
+      loc, sizeAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, sizes[1]));
+
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
-  sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
-  sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
-  SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+  OpFoldResult offsetH =
+      inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
+  OpFoldResult offsetW =
+      inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
+  sliceOffsets.append({offsets[2], offsetH, offsetW, offsets[3]});
+  OpFoldResult sizeH =
+      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+  sliceSizes.append({sizes[2], sizeH, sizeW, sizes[3]});
+  int64_t inputRank = getInputOperandRank();
+  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
 
@@ -3004,7 +3131,8 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                    resultSizes)))
     return failure();
 
-  SmallVector<OpFoldResult> outputStrides(6, oneAttr);
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getOutput(), resultOffsets, resultSizes, outputStrides));
 
@@ -3032,7 +3160,8 @@ LogicalResult WinogradOutputTransformOp::verify() {
   bool leftTransform = valueH != 1;
   bool rightTransform = valueW != 1;
 
-  SmallVector<int64_t> expectedOutputShape(4, valueH);
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
   if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
     expectedOutputShape[1] = ShapedType::kDynamic;
   } else {
@@ -3061,22 +3190,23 @@ LogicalResult WinogradOutputTransformOp::verify() {
 SmallVector<Range>
 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
-  IndexType indexType = builder.getIndexType();
-  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
-  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
   Value value = getValue();
-  SmallVector<Range> loopBounds(6);
-  for (unsigned dim = 0; dim < 6; ++dim) {
-    loopBounds[dim].offset = zeroAttr;
-    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
-    loopBounds[dim].stride = oneAttr;
+  int64_t valueRank = getValueOperandRank();
+  SmallVector<Range> loopBounds(valueRank - 2);
+  for (unsigned dim = 2; dim < valueRank; ++dim) {
+    loopBounds[dim - 2].offset = zeroAttr;
+    loopBounds[dim - 2].size = getDimValue(builder, loc, value, dim);
+    loopBounds[dim - 2].stride = oneAttr;
   }
   return loopBounds;
 }
 
 SmallVector<utils::IteratorType>
 WinogradOutputTransformOp::getLoopIteratorTypes() {
-  SmallVector<utils::IteratorType> iteratorTypes(6,
+  int64_t valueRank = getValueOperandRank();
+  SmallVector<utils::IteratorType> iteratorTypes(valueRank - 2,
                                                  utils::IteratorType::parallel);
   return iteratorTypes;
 }
@@ -3085,32 +3215,49 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  Value output = getOutput();
-  auto outputType = cast<ShapedType>(output.getType());
-  ArrayRef<int64_t> outputShape = outputType.getShape();
-  int64_t outputH = outputShape[1];
-  int64_t outputW = outputShape[2];
   int64_t m = getM();
-  IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
-  IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
 
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
   auto affineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+  Value mappedOffsetH = builder.create<affine::AffineApplyOp>(
       loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
-  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[0]));
+  Value mappedOffsetW = builder.create<affine::AffineApplyOp>(
       loc, affineMap,
-      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[1]));
+  Value mappedSizeH = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueOrCreateConstantIndexOp(builder, loc, sizes[0]));
+  Value mappedSizeW = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueOrCreateConstantIndexOp(builder, loc, sizes[1]));
 
-  resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
-  resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
+  ShapedType valueType = getValueOperandType();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  OpFoldResult offsetH =
+      valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
+  OpFoldResult offsetW =
+      valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
+  OpFoldResult sizeH =
+      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+
+  resultOffsets.append({offsets[2], offsetH, offsetW, offsets[3]});
+  resultSizes.append({sizes[2], sizeH, sizeW, sizes[3]});
   return success();
 }
 
+/// Implement tiling for winograd_output_transform.
+/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
+/// F). The output of winograd_output_transform is (N, H, W, F). Users can
+/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
+/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
+/// for the sizes of tileH, tileW, N, F for one tile.
 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes) {
@@ -3120,10 +3267,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
+  ShapedType valueType = getValueOperandType();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t alphaH = valueShape[0];
+  int64_t alphaW = valueShape[1];
+  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
   sliceOffsets.append(
-      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
-  sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
-  SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
+      {zeroAttr, zeroAttr, offsets[0], offsets[1], offsets[2], offsets[3]});
+  sliceSizes.append(
+      {alphaHAttr, alphaWAttr, sizes[0], sizes[1], sizes[2], sizes[3]});
+  int64_t valueRank = getValueOperandRank();
+  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
 
@@ -3132,7 +3288,8 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
                                    resultSizes)))
     return failure();
 
-  SmallVector<OpFoldResult> strides(4, oneAttr);
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
   tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
       loc, getOutput(), resultOffsets, resultSizes, strides));
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index c01a5f4b055f1..b65b18699a15a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -490,8 +490,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
   Type elementType = inputType.getElementType();
   auto inputShape = inputType.getShape(); // N, H, W, C
   int64_t inputN = inputShape[0];
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
   int64_t inputC = inputShape[3];
   auto valueType = cast<ShapedType>(retValue.getType());
   auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
@@ -500,11 +498,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
   int64_t alphaH = leftTransform ? m + r - 1 : 1;
   int64_t alphaW = rightTransform ? m + r - 1 : 1;
 
-  if ((inputH != (tileH * m) + (r - 1)) && inputH != 1)
-    return Value();
-  if ((inputW != (tileW * m) + (r - 1)) && inputW != 1)
-    return Value();
-
   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                        ValueRange args) -> scf::ValueVector {
     Value tileHIter = ivs[0];
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 6bb3fb1423edc..54eb755047f18 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -18,9 +18,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
@@ -129,9 +129,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
@@ -235,9 +235,9 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
     %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
     %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
new file mode 100644
index 0000000000000..b1b2818db38fc
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -0,0 +1,392 @@
+// RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s
+
+func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+// CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:    %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK:    %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:    %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
+// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK:      %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
+// CHECK:      %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+
+// -----
+
+func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:     %[[C5:.*]] = arith.constant 5 : index
+// CHECK:     %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C2_1]]
+// CHECK:       %[[C5_2:.*]] = arith.constant 5 : index
+// CHECK:       %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
+// CHECK:       %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
+// CHECK:       %[[S4:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
+
+// -----
+
+func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  return %0 : tensor<6x1x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x3x1x5xf32>, %[[ARG1:.*]]: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:     %[[C5:.*]] = arith.constant 5 : index
+// CHECK:     %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK:       %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
+// CHECK:       %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+
+// -----
+
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop3:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK:   %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:   %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:   %[[C1_3:.*]] = arith.constant 1 : index
+// CHECK:   %[[S5:.*]] = affine.apply #[[$MAP1]](%[[C1_3]])
+// CHECK:   %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK:   %[[S6:.*]] = affine.apply #[[$MAP1]](%[[C1_4]])
+// CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
+// CHECK:   %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:   %[[S7:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+
+// -----
+
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:   %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK:     %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK:     %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK:     %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
+// CHECK:       %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK:       %[[C5:.*]] = arith.constant 5 : index
+// CHECK:       %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
+// CHECK:         %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:         %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:         %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK:         %[[S7:.*]] = affine.apply #[[$MAP1]](%[[C1_8]])
+// CHECK:         %[[C1_9:.*]] = arith.constant 1 : index
+// CHECK:         %[[S8:.*]] = affine.apply #[[$MAP1]](%[[C1_9]])
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
+// CHECK:         %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
+// CHECK:         %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+
+// -----
+
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
+// CHECK:   %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
+// CHECK:     %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK:     %[[C2_5:.*]] = arith.constant 2 : index
+// CHECK:     %[[C2_6:.*]] = arith.constant 2 : index
+// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C2_5]] step %[[C2_6]]
+// CHECK:       %[[C0_7:.*]] = arith.constant 0 : index
+// CHECK:       %[[C5:.*]] = arith.constant 5 : index
+// CHECK:       %[[C2_8:.*]] = arith.constant 2 : index
+// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_7]] to %[[C5]] step %[[C2_8]]
+// CHECK:         %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG8]])
+// CHECK:         %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG2]])
+// CHECK:         %[[S7:.*]] = affine.apply #[[$MAP1]](%[[ARG4]])
+// CHECK:         %[[C2_10:.*]] = arith.constant 2 : index
+// CHECK:         %[[S8:.*]] = affine.apply #[[$MAP2]](%[[C2_10]])
+// CHECK:         %[[C2_11:.*]] = arith.constant 2 : index
+// CHECK:         %[[S9:.*]] = affine.apply #[[$MAP2]](%[[C2_11]])
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
+// CHECK:         %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
+// CHECK:         %[[S10:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+
+// -----
+
+func.func @tile_winograd_input(%arg0: tensor<2x1x10x5xf32>, %arg1: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
+  return %0 : tensor<1x6x1x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4 + 2)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x1x10x5xf32>, %[[ARG1:.*]]: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C1_0:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_0]]
+// CHECK:     %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]]
+// CHECK:       %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK:       %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK:       %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
+// CHECK:         %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK:         %[[C5:.*]] = arith.constant 5 : index
+// CHECK:         %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
+// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:           %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP1]](%[[C1_8]])
+// CHECK:           %[[C1_9:.*]] = arith.constant 1 : index
+// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP1]](%[[C1_9]])
+// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], 0, %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
+// CHECK:           %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
+// CHECK:           %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+
+// -----
+
+func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x2x2xf32>, %arg1: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @tile_winograd_output(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<6x6x2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:     %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK:       %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:       %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:       %[[C1_3:.*]] = arith.constant 1 : index
+// CHECK:       %[[S5:.*]] = affine.apply #[[$MAP0]](%[[C1_3]])
+// CHECK:       %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK:       %[[S6:.*]] = affine.apply #[[$MAP0]](%[[C1_4]])
+// CHECK:       %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
+
+// -----
+
+func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x3x5xf32>, %arg1: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
+  return %0 : tensor<3x8x8x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (2, -d0 + 3)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (2, -d0 + 5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL:  func.func @tile_winograd_output(
+// CHECK-SAME:   %[[ARG0:.*]]: tensor<6x6x2x2x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
+// CHECK:    %[[S0:.*]] = tensor.empty() : tensor<3x8x8x5xf32>
+// CHECK:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK:    %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK:    %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
+// CHECK:      %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:      %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK:      %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK:      %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
+// CHECK:        %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK:        %[[C3:.*]] = arith.constant 3 : index
+// CHECK:        %[[C2_5:.*]] = arith.constant 2 : index
+// CHECK:        %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]]
+// CHECK:          %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK:          %[[C5:.*]] = arith.constant 5 : index
+// CHECK:          %[[C2_7:.*]] = arith.constant 2 : index
+// CHECK:          %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]]
+// CHECK:            %[[C3_8:.*]] = arith.constant 3 : index
+// CHECK:            %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]])
+// CHECK:            %[[C5_9:.*]] = arith.constant 5 : index
+// CHECK:            %[[S6:.*]] = affine.min #[[$MAP1]](%[[ARG8]])
+// CHECK:            %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, %[[S5]], %[[S6]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x3x5xf32> to tensor<6x6x2x2x?x?xf32>
+// CHECK:            %[[S7:.*]] = affine.apply #[[$MAP2]](%[[ARG2]])
+// CHECK:            %[[S8:.*]] = affine.apply #[[$MAP2]](%[[ARG4]])
+// CHECK:            %[[C2_10:.*]] = arith.constant 2 : index
+// CHECK:            %[[S9:.*]] = affine.apply #[[$MAP2]](%[[C2_10]])
+// CHECK:            %[[C2_11:.*]] = arith.constant 2 : index
+// CHECK:            %[[S10:.*]] = affine.apply #[[$MAP2]](%[[C2_11]])
+// CHECK:            %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
+
+// -----
+
+func.func @tile_winograd_output(%arg0 : tensor<6x1x2x1x3x5xf32>, %arg1: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
+  return %0 : tensor<3x8x1x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @tile_winograd_output(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<6x1x2x1x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK:     %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK:     %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:     %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]]
+// CHECK:       %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK:       %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]]
+// CHECK:         %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK:         %[[C5:.*]] = arith.constant 5 : index
+// CHECK:         %[[C1_6:.*]] = arith.constant 1 : index
+// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]]
+// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
+// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:           %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP0]](%[[C1_7]])
+// CHECK:           %[[C1_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP0]](%[[C1_8]])
+// CHECK:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
+// CHECK:           %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>



More information about the Mlir-commits mailing list