[Mlir-commits] [mlir] [mlir][linalg] Implement Winograd Conv2D. (PR #94470)

Hsiangkai Wang llvmlistbot at llvm.org
Fri Jun 14 06:39:42 PDT 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/94470

>From 968bf012f2a04f264aa1b04c8b08784c61b249da Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 17 May 2024 14:55:15 +0100
Subject: [PATCH] [mlir][linalg] Implement Winograd Conv2D.

This patch implements the Winograd Conv2D algorithm. It supports several
configurations of Winograd Conv2D, including F(2, 3), F(4, 3) and F(2, 5).
These configurations show that the implementation can support different
kernel sizes (3 and 5) and different output sizes (2 and 4). Besides the
symmetric 3x3 and 5x5 kernels, this patch also supports 1x3, 3x1, 1x5, and
5x1 kernels.

The implementation is based on the paper Fast Algorithms for
Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |   92 ++
 .../Linalg/TransformOps/LinalgTransformOps.td |   80 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |   19 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  330 ++++++
 .../TransformOps/LinalgTransformOps.cpp       |   52 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |    1 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 1022 +++++++++++++++++
 .../transform-tile-and-winograd-rewrite.mlir  |  166 +++
 .../Linalg/transform-winograd-conv2d.mlir     |   41 +
 .../Linalg/winograd-conv2d-rewrite.mlir       |  105 ++
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir |  169 +++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |   26 +
 12 files changed, 2103 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
 create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..15b53960005fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,96 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+    [DeclareOpInterfaceMethods<TilingInterface, // impl. in LinalgOps.cpp
+     ["getIterationDomain",
+      "getLoopIteratorTypes",
+      "getResultTilePosition",
+      "getTiledImplementation"]>]> {
+  let summary = "Winograd filter transform operator"; // presumably G x g x G^T
+  let description = [{
+    linalg.winograd_filter_transform transforms the filter of conv2D.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$filter,  // kernel H/W at dims 1/2
+                       AnyRankedTensor:$output,  // its shape sizes dims 2/3
+                       I64Attr:$output_height,   // size of iteration dim 0
+                       I64Attr:$output_width,    // size of iteration dim 1
+                       I64Attr:$m,               // F(m, r): output dimension
+                       I64Attr:$r                // F(m, r): filter dimension
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `output_height` `(` $output_height `)`
+    `output_width` `(` $output_width `)`
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+}
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+    [DeclareOpInterfaceMethods<TilingInterface, // impl. in LinalgOps.cpp
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
+  let summary = "Winograd input transform operator"; // presumably B^T x d x B
+  let description = [{
+    linalg.winograd_input_transform transforms the input of conv2D.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,   // H/W at dims 1/2 (NHWC?)
+                       AnyRankedTensor:$output,  // its shape sizes dims 2/3
+                       I64Attr:$output_height,   // size of iteration dim 0
+                       I64Attr:$output_width,    // size of iteration dim 1
+                       I64Attr:$m,               // F(m, r): output dimension
+                       I64Attr:$r                // F(m, r): filter dimension
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `output_height` `(` $output_height `)`
+    `output_width` `(` $output_width `)`
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+}
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+    [DeclareOpInterfaceMethods<TilingInterface, // impl. in LinalgOps.cpp
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
+  let summary = "Winograd output transform operator"; // presumably A^T x . x A
+  let description = [{
+    linalg.winograd_output_transform transforms the output of conv2D.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$value,   // dims 0/1: alpha-sized tiles
+                       AnyRankedTensor:$output,  // iteration domain = its shape
+                       I64Attr:$m,               // F(m, r): output dimension
+                       I64Attr:$r                // F(m, r): filter dimension
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..6742f5024cd8b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,84 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Use Winograd Conv2D algorithm to compute Conv2D. It will decompose conv2d
+    into three transform operators, i.e., filter transform operator, input
+    transform operator and output transform operator. In addition, use batched
+    matmul to compute the transformed filter and input matrices.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle of the sequence that replaces the original
+    convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,   // F(m, r): output dimension
+                       I64Attr:$r);  // F(m, r): filter dimension
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = // m and r are printed through attr-dict
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)> // NOTE(review): sets no m/r; confirm
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target, // only conv_2d_nhwc_fhwc handled
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def WinogradConv2DRewriteOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d_rewrite",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Rewrite winograd conv2D operators. It will convert filter, input and
+    output transform operators into a combination of scf, tensor, and linalg
+    equivalent operators. Before applying this transform operator, users
+    need to tile winograd transform operators into supported sizes.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle of the sequence that replaces the original
+    operator.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target); // winograd ops
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target, // filter/input/output transform op
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..7126b088d19d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,20 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to the Winograd Conv2D algorithm F(m, r).
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+FailureOr<Operation *> // Lower winograd_filter_transform to scf/tensor/linalg.
+winogradConv2DRewriteFilterTransform(RewriterBase &rewriter,
+                                     linalg::WinogradFilterTransformOp op);
+FailureOr<Operation *> // Lower winograd_input_transform to scf/tensor/linalg.
+winogradConv2DRewriteInputTransform(RewriterBase &rewriter,
+                                    linalg::WinogradInputTransformOp op);
+FailureOr<Operation *> // Lower winograd_output_transform to scf/tensor/linalg.
+winogradConv2DRewriteOutputTransform(RewriterBase &rewriter,
+                                     linalg::WinogradOutputTransformOp op);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
@@ -1692,6 +1706,11 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply the Winograd Conv2D algorithm F(m, r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+void populateWinogradConv2DRewritePatterns(RewritePatternSet &patterns); // presumably lowers the winograd_*_transform ops
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..d0e8073d2c8df 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2737,6 +2737,336 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  // The iteration space is 4-D: dims 0/1 range over the output height/width
+  // given by the op attributes, dims 2/3 over the matching output dims.
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value output = getOutput();
+  SmallVector<Range> loopBounds(4, Range{zero, zero, one});
+  // Range sizes are index-typed; use index attributes (not i64) so that no
+  // i64-typed constant can leak into index arithmetic downstream.
+  loopBounds[0].size =
+      builder.getIndexAttr(static_cast<int64_t>(getOutputHeight()));
+  loopBounds[1].size =
+      builder.getIndexAttr(static_cast<int64_t>(getOutputWidth()));
+  for (int64_t dim = 2; dim < 4; ++dim)
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  // All four loops of the iteration domain are parallel.
+  return SmallVector<utils::IteratorType>(/*Size=*/4,
+                                          utils::IteratorType::parallel);
+}
+
+/// Materialize `opFoldResult` as an SSA value. The result feeds affine.apply,
+/// whose operands must be `index`, so integer attributes of any width (e.g.
+/// the i64 attributes built in this file) become index-typed constants.
+/// `static`: this is a file-local helper, not part of the dialect API.
+static Value getValueFromOpFoldResult(OpFoldResult opFoldResult,
+                                      OpBuilder &builder, Location loc) {
+  if (auto val = opFoldResult.dyn_cast<Value>())
+    return val;
+  auto intAttr = cast<IntegerAttr>(opFoldResult.get<Attribute>());
+  return builder.create<arith::ConstantIndexOp>(loc, intAttr.getInt());
+}
+
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  // Kernel H/W are filterShape[1]/[2] (FHWC layout, presumably).
+  Value filter = getFilter();
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = filterH != 1 ? alpha : 1; // size-1 dims keep size 1
+  int64_t alphaW = filterW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  auto context = builder.getContext();
+  auto affineMap = AffineMap::get( // o -> (o floordiv m) * alpha
+      1, 0, {builder.getAffineDimExpr(0).floorDiv(m) * alpha}, context);
+
+  Location loc = getLoc();
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+  // Dims 2/3 map 1:1 to the result; a constant 0 would make tiles overlap.
+  resultOffsets.push_back(mappedOffset1);
+  resultOffsets.push_back(mappedOffset2);
+  resultOffsets.push_back(offsets[2]);
+  resultOffsets.push_back(offsets[3]);
+  resultSizes.push_back(alphaHAttr);
+  resultSizes.push_back(alphaWAttr);
+  resultSizes.push_back(sizes[2]);
+  resultSizes.push_back(sizes[3]);
+  return success();
+}
+
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  // NOTE(review): the filter is passed whole; tiling dims 2/3 may need a slice.
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> strides(4, oneAttr);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(getFilter());
+
+  // The op has a single result, so query the tile position of result #0.
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   sliceOffsets, sliceSizes)))
+    return failure();
+
+  // The tiled op writes into the matching slice of `output`.
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+  // Clone with the sliced output's type as the single result type.
+  SmallVector<Type, 1> resultTypes = {tiledOperands[1].getType()};
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  // 4-D iteration space: dims 0/1 span output_height/output_width, dims 2/3
+  // span the matching dims of the output tensor.
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value output = getOutput();
+  SmallVector<Range> loopBounds(4, Range{zero, zero, one});
+  // Use index (not i64) attributes: Range sizes are index-typed.
+  loopBounds[0].size =
+      builder.getIndexAttr(static_cast<int64_t>(getOutputHeight()));
+  loopBounds[1].size =
+      builder.getIndexAttr(static_cast<int64_t>(getOutputWidth()));
+  for (int64_t dim = 2; dim < 4; ++dim)
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  // All four loops of the iteration domain are parallel.
+  return SmallVector<utils::IteratorType>(/*Size=*/4,
+                                          utils::IteratorType::parallel);
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  // Input H/W are inputShape[1]/[2] (NHWC layout, presumably).
+  Value input = getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1; // size-1 dims keep size 1
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  auto context = builder.getContext();
+  auto affineMap = AffineMap::get( // o -> (o floordiv m) * alpha
+      1, 0, {builder.getAffineDimExpr(0).floorDiv(m) * alpha}, context);
+
+  Location loc = getLoc();
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+  // Dims 2/3 map 1:1 to the result; a constant 0 would make tiles overlap.
+  resultOffsets.push_back(mappedOffset1);
+  resultOffsets.push_back(mappedOffset2);
+  resultOffsets.push_back(offsets[2]);
+  resultOffsets.push_back(offsets[3]);
+  resultSizes.push_back(alphaHAttr);
+  resultSizes.push_back(alphaWAttr);
+  resultSizes.push_back(sizes[2]);
+  resultSizes.push_back(sizes[3]);
+  return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  // Input (NHWC, presumably): iteration dims 2/3 map to input dims 0/3.
+  Value input = getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> strides(4, oneAttr);
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+  // alphaH x alphaW input window at the tile's spatial offset.
+  sliceOffsets.push_back(offsets[2]);
+  sliceOffsets.push_back(offsets[0]);
+  sliceOffsets.push_back(offsets[1]);
+  sliceOffsets.push_back(offsets[3]);
+  sliceSizes.push_back(sizes[2]);
+  sliceSizes.push_back(alphaHAttr);
+  sliceSizes.push_back(alphaWAttr);
+  sliceSizes.push_back(sizes[3]);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getInput(), sliceOffsets, sliceSizes, strides));
+
+  sliceOffsets.clear();
+  sliceSizes.clear();
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   sliceOffsets, sliceSizes)))
+    return failure();
+
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  // Iterate over the first four dims of `output`: [0, dim) with unit step.
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value output = getOutput();
+  SmallVector<Range> loopBounds;
+  loopBounds.reserve(4);
+  for (int64_t dim = 0; dim < 4; ++dim)
+    loopBounds.push_back(
+        Range{zero, getDimValue(builder, loc, output, dim), one});
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  // All four loops of the iteration domain are parallel.
+  return SmallVector<utils::IteratorType>(/*Size=*/4,
+                                          utils::IteratorType::parallel);
+}
+
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  // The iteration domain is the output shape, so every offset maps 1:1.
+  int64_t m = getM();
+  auto mAttr = builder.getI64IntegerAttr(m);
+  // NOTE: sizes[1]/[2] are ignored; spatial tiles are assumed to be m wide.
+  resultOffsets.push_back(offsets[0]);
+  resultOffsets.push_back(offsets[1]);
+  resultOffsets.push_back(offsets[2]);
+  resultOffsets.push_back(offsets[3]);
+  resultSizes.push_back(sizes[0]);
+  resultSizes.push_back(mAttr);
+  resultSizes.push_back(mAttr);
+  resultSizes.push_back(sizes[3]);
+  return success();
+}
+
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  // `value` dims 0/1 hold alpha-sized tiles; dims 2/3 follow domain dims 0/3.
+  Value value = getValue();
+  auto valueType = cast<ShapedType>(value.getType());
+  auto valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = valueH != 1 ? alpha : 1;
+  int64_t alphaW = valueW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> strides(4, oneAttr);
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  auto context = builder.getContext();
+  auto affineMap = AffineMap::get( // o -> (o floordiv m) * alpha
+      1, 0, {builder.getAffineDimExpr(0).floorDiv(m) * alpha}, context);
+
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+  // Dims 2/3 follow offsets[0]/[3]; a constant 0 would make tiles overlap.
+  sliceOffsets.push_back(mappedOffset1);
+  sliceOffsets.push_back(mappedOffset2);
+  sliceOffsets.push_back(offsets[0]);
+  sliceOffsets.push_back(offsets[3]);
+  sliceSizes.push_back(alphaHAttr);
+  sliceSizes.push_back(alphaWAttr);
+  sliceSizes.push_back(sizes[0]);
+  sliceSizes.push_back(sizes[3]);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, value, sliceOffsets, sliceSizes, strides));
+
+  sliceOffsets.clear();
+  sliceSizes.clear();
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   sliceOffsets, sliceSizes)))
+    return failure();
+
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..3627b1a0d46c4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,58 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Emit the Winograd decomposition right before the matched op.
+  rewriter.setInsertionPoint(target);
+  // Only linalg.conv_2d_nhwc_fhwc is currently supported; anything else is
+  // reported through the rewriter's match-failure hook.
+  FailureOr<Operation *> maybeTransformed = failure();
+  if (auto conv = dyn_cast<linalg::Conv2DNhwcFhwcOp>(target.getOperation()))
+    maybeTransformed = winogradConv2D(rewriter, conv, getM(), getR());
+  else
+    maybeTransformed =
+        rewriter.notifyMatchFailure(target.getOperation(), "not supported");
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure transform::WinogradConv2DRewriteOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // New IR is materialized right before the Winograd transform op being
+  // rewritten.
+  rewriter.setInsertionPoint(target);
+
+  // Dispatch on the concrete Winograd transform op; other ops are rejected.
+  FailureOr<Operation *> maybeTransformed = failure();
+  if (auto filter = dyn_cast<linalg::WinogradFilterTransformOp>(target))
+    maybeTransformed = winogradConv2DRewriteFilterTransform(rewriter, filter);
+  else if (auto input = dyn_cast<linalg::WinogradInputTransformOp>(target))
+    maybeTransformed = winogradConv2DRewriteInputTransform(rewriter, input);
+  else if (auto output = dyn_cast<linalg::WinogradOutputTransformOp>(target))
+    maybeTransformed = winogradConv2DRewriteOutputTransform(rewriter, output);
+  else
+    maybeTransformed = rewriter.notifyMatchFailure(target, "not supported");
+
+  // Both a failed rewrite and an unsupported op yield a silenceable failure.
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..d58acb718c3b9
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,1022 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// Implement Winograd Conv2D algorithm. The implementation is based on the
+/// paper: Fast Algorithms for Convolutional Neural Networks
+/// (https://arxiv.org/abs/1509.09308)
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to compute its
+// result. For F(m x m, r x r), where m is the output dimension and r is the
+// filter dimension, the minimal 2D filtering algorithm computes
+//
+// Y = A^T x [ (G x g x G^T) (.) (B^T x d x B) ] x A
+//
+// where g is the filter, d is the input data, "x" is matrix multiplication
+// and "(.)" is element-wise multiplication. This requires six constant
+// transformation matrices: G, G^T, B^T, B, A^T, and A. The tables below
+// define them for F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5);
+// each X^T table holds the transpose of the corresponding X table.
+constexpr float G_2x2_3x3[] = { // 4x3 (alpha x r), F(2, 3)
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = { // 3x4, transpose of G_2x2_3x3
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = { // 4x4 (alpha x alpha)
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = { // 4x4, transpose of BT_2x2_3x3
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = { // 2x4 (m x alpha)
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = { // 4x2, transpose of AT_2x2_3x3
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+constexpr float G_4x4_3x3[] = { // 6x3 (alpha x r), F(4, 3)
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = { // 3x6, transpose of G_4x4_3x3
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = { // 6x6 (alpha x alpha)
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = { // 6x6, transpose of BT_4x4_3x3
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = { // 4x6 (m x alpha)
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = { // 6x4, transpose of AT_4x4_3x3
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+constexpr float G_2x2_5x5[] = { // 6x5 (alpha x r), F(2, 5)
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = { // 5x6, transpose of G_2x2_5x5
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = { // 6x6 (alpha x alpha)
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = { // 6x6, transpose of BT_2x2_5x5
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = { // 2x6 (m x alpha)
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = { // 6x2, transpose of AT_2x2_5x5
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
+/// Identifies a Winograd configuration F(m, r) by the pair (m, r).
+using TransformMapKeyTy = std::pair<int, int>;
+
+// Minimal filtering algorithms are written F(m, r): m is the size of the
+// output tile and r is the size of the filter. Together they fix the size of
+// the input tile, alpha = m + r - 1.
+//
+// For instance, F(2, 3) has alpha = 4: convolving a 4x4 input tile with a
+// 3x3 filter yields a 2x2 output tile.
+constexpr TransformMapKeyTy F_2_3 = {2, 3}; // alpha = 4
+constexpr TransformMapKeyTy F_4_3 = {4, 3}; // alpha = 6
+constexpr TransformMapKeyTy F_2_5 = {2, 5}; // alpha = 6
+
+/// A constant row-major transform matrix plus the factor by which a result
+/// produced with it has to be multiplied (1 means no rescaling).
+struct TransformMatrix {
+  TransformMatrix(const float *data, int64_t numRows, int64_t numCols,
+                  int64_t factor = 1)
+      : table(data), rows(numRows), cols(numCols), scalarFactor(factor) {}
+
+  /// Row-major storage holding rows * cols entries; not owned.
+  const float *table;
+  /// Number of rows.
+  int64_t rows;
+  /// Number of columns.
+  int64_t cols;
+  /// Multiplier for results computed with this matrix.
+  int64_t scalarFactor;
+};
+
+/// Materializes `transform` as an arith.constant of type
+/// tensor<rows x cols x type>, reading the table in row-major order.
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> values(transform.table, transform.rows * transform.cols);
+  auto tensorTy = RankedTensorType::get({transform.rows, transform.cols}, type);
+  auto denseAttr = DenseFPElementsAttr::get(tensorTy, values);
+  return rewriter.create<arith::ConstantOp>(loc, denseAttr);
+}
+
+/// Extracts a rank-2 (height x width) slice from the rank-4 `source`.
+///
+/// Dimensions `outLoopIdx` and `inLoopIdx` are fixed at `outLoopIndex` and
+/// `inLoopIndex`. Dimensions `heightIdx` and `widthIdx` are taken whole, and
+/// the remaining unit dimensions are dropped by a rank-reducing slice.
+Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx) {
+  auto srcTy = cast<ShapedType>(source.getType());
+  auto srcShape = srcTy.getShape();
+  const int64_t tileH = srcShape[heightIdx];
+  const int64_t tileW = srcShape[widthIdx];
+  Type elemTy = srcTy.getElementType();
+
+  auto zero = rewriter.getIndexAttr(0);
+  auto one = rewriter.getIndexAttr(1);
+
+  // Offsets are 0 except along the two loop-driven dimensions.
+  SmallVector<OpFoldResult, 4> sliceOffsets(4, zero);
+  sliceOffsets[outLoopIdx] = outLoopIndex;
+  sliceOffsets[inLoopIdx] = inLoopIndex;
+
+  // Sizes are 1 except along the height and width dimensions.
+  SmallVector<OpFoldResult, 4> sliceSizes(4, one);
+  sliceSizes[heightIdx] = rewriter.getIndexAttr(tileH);
+  sliceSizes[widthIdx] = rewriter.getIndexAttr(tileW);
+
+  SmallVector<OpFoldResult, 4> unitStrides(4, one);
+
+  SmallVector<int64_t> sliceShape(4, 1);
+  sliceShape[heightIdx] = tileH;
+  sliceShape[widthIdx] = tileW;
+
+  Value slice = rewriter.create<tensor::ExtractSliceOp>(
+      loc, RankedTensorType::get(sliceShape, elemTy), source, sliceOffsets,
+      sliceSizes, unitStrides);
+
+  return tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, slice, RankedTensorType::get({tileH, tileW}, elemTy));
+}
+
+/// Inserts the rank-2 `source` (height x width) into the rank-4 `dest`.
+///
+/// `source` is first lifted to rank 4, with unit extents outside `heightIdx`
+/// and `widthIdx`. It is then written at offset `outLoopIndex` along
+/// `outLoopIdx` and `inLoopIndex` along `inLoopIdx` (0 elsewhere). Returns
+/// the updated `dest`.
+Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
+                   Value dest, Value outLoopIndex, Value inLoopIndex,
+                   int64_t height, int64_t width, int64_t outLoopIdx,
+                   int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx) {
+  Type elemTy = cast<ShapedType>(source.getType()).getElementType();
+
+  // Lift the 2-D tile back to rank 4.
+  SmallVector<int64_t> tileShape(4, 1);
+  tileShape[heightIdx] = height;
+  tileShape[widthIdx] = width;
+  Value emptyTile = rewriter.create<tensor::EmptyOp>(loc, tileShape, elemTy);
+  Value rank4Tile = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, source, emptyTile);
+
+  // Position the tile inside `dest`.
+  auto zero = rewriter.getIndexAttr(0);
+  auto one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> dstOffsets(4, zero);
+  dstOffsets[outLoopIdx] = outLoopIndex;
+  dstOffsets[inLoopIdx] = inLoopIndex;
+  SmallVector<OpFoldResult, 4> dstSizes(4, one);
+  dstSizes[heightIdx] = rewriter.getIndexAttr(height);
+  dstSizes[widthIdx] = rewriter.getIndexAttr(width);
+  SmallVector<OpFoldResult, 4> unitStrides(4, one);
+
+  return rewriter.create<tensor::InsertSliceOp>(loc, rank4Tile, dest,
+                                                dstOffsets, dstSizes,
+                                                unitStrides);
+}
+
+/// Folds the two leading dimensions of a rank-4 tensor into one:
+/// [d0, d1, d2, d3] -> [d0 * d1, d2, d3].
+Value collaps2DData(RewriterBase &rewriter, Location loc, Value data) {
+  auto dataTy = cast<ShapedType>(data.getType());
+  auto dims = dataTy.getShape();
+  SmallVector<int64_t> collapsedDims = {dims[0] * dims[1], dims[2], dims[3]};
+  auto collapsedTy =
+      RankedTensorType::get(collapsedDims, dataTy.getElementType());
+  SmallVector<ReassociationIndices> groups = {{0, 1}, {2}, {3}};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedTy, data,
+                                                  groups);
+}
+
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is 2-dimension. We need to extract H x W from
+// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+// After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract filter<h x w> from filter<f x h x w x c>
+//     %ret = linalg.matmul G, %extracted
+//     %ret = linalg.matmul %ret, GT
+//     %inserted = insert %ret into filter<h x w x c x f>
+//
+// Returns a null Value, without creating any IR, in two cases:
+//  - the filter shape does not match r for the requested transform sides;
+//  - no G/GT matrix exists for (m, r).
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  // A transformed dimension must be exactly r wide, because G is alpha x r.
+  // An untransformed dimension must be 1. Any other combination would build
+  // matmuls with mismatched operand shapes.
+  if (filterH != (leftTransform ? r : 1))
+    return Value();
+  if (filterW != (rightTransform ? r : 1))
+    return Value();
+
+  // Resolve every transform matrix before emitting IR. Bailing out after the
+  // loops below exist would leave partially built IR behind, with the
+  // insertion point inside a loop body.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *GMatrix = nullptr;
+  if (leftTransform) {
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    GMatrix = &it->second;
+  }
+  const TransformMatrix *GTMatrix = nullptr;
+  if (rightTransform) {
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    GTMatrix = &it->second;
+  }
+
+  // Return shape is <H x W x C x F>
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value FIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (F, H, W, C)
+  auto extractFilter =
+      extract2DData(rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+                    /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+  int64_t retRows = 1;
+  Value matmulRetValue = extractFilter;
+  if (leftTransform) {
+    retRows = GMatrix->rows;
+    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value G = create2DTransformMatrix(rewriter, loc, *GMatrix, elementType);
+    // Multiply G x g
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    auto matmulType =
+        RankedTensorType::get({retRows, GTMatrix->cols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value GT = create2DTransformMatrix(rewriter, loc, *GTMatrix, elementType);
+    // Multiply u = (G x g) x GT
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert (H, W) to (H, W, C, F)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  int64_t retHeight = leftTransform ? m + r - 1 : 1;
+  int64_t retWidth = rightTransform ? m + r - 1 : 1;
+  auto insertSliceOp = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth,
+      /*outLoopIdx=*/3, /*inLoopIdx=*/2, /*heightIdx=*/0, /*widthIdx=*/1);
+
+  rewriter.create<scf::YieldOp>(loc, insertSliceOp);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// This function transforms the input. The data layout of the input is NHWC.
+// The transformation matrix is 2-dimension. We need to extract H x W from
+// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract input<h x w> from input<n x h x w x c>
+//     %ret = linalg.matmul BT, %extracted
+//     %ret = linalg.matmul %ret, B
+//     %inserted = insert %ret into input<h x w x n x c>
+//
+// The input must already be a single alpha x alpha tile along each
+// transformed dimension, and 1 along an untransformed one. Returns a null
+// Value, without creating any IR, in two cases:
+//  - that shape requirement is not met;
+//  - no BT/B matrix exists for (m, r).
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  // Any other extent would produce matmuls with mismatched operand shapes.
+  if (inputH != alphaH || inputW != alphaW)
+    return Value();
+
+  // Resolve every transform matrix before emitting IR so that a failure
+  // never leaves partially built loops behind.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *BTMatrix = nullptr;
+  if (leftTransform) {
+    auto it = BTMatrices.find(key);
+    if (it == BTMatrices.end())
+      return Value();
+    BTMatrix = &it->second;
+  }
+  const TransformMatrix *BMatrix = nullptr;
+  if (rightTransform) {
+    auto it = BMatrices.find(key);
+    if (it == BMatrices.end())
+      return Value();
+    BMatrix = &it->second;
+  }
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (N, H, W, C)
+  auto extractInput =
+      extract2DData(rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+                    /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  Value matmulRetValue = extractInput;
+  if (leftTransform) {
+    retRows = BTMatrix->rows;
+    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    // Build the constant in the input's element type so that both matmul
+    // operands agree with the init tensor, as in the filter/output transforms.
+    Value BT = create2DTransformMatrix(rewriter, loc, *BTMatrix, elementType);
+    // Multiply BT x d
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    retCols = BMatrix->cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+    Value B = create2DTransformMatrix(rewriter, loc, *BMatrix, elementType);
+    // Multiply v = (BT x d) x B
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert v
+  // Insert (H, W) to (H, W, N, C)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  auto combinedVal = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols,
+      /*outLoopIdx=*/2, /*inLoopIdx=*/3, /*heightIdx=*/0, /*widthIdx=*/1);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// Multiplies the transformed input by the transformed filter with a single
+// linalg.batch_matmul. batch_matmul needs rank-3 operands, so the two
+// leading (tile) dimensions of each operand are first folded into one batch
+// dimension:
+//
+//   input  [H, W, N, C] -> [H x W, N, C]
+//   filter [H, W, C, F] -> [H x W, C, F]
+//
+// Every batch entry contracts over the channel dimension C. The [H x W, N, F]
+// product is then expanded again, so the result has layout (H, W, N, F):
+//
+//   %collapsed_input  = tensor.collapse_shape %input
+//   %collapsed_filter = tensor.collapse_shape %filter
+//   %ret              = linalg.batch_matmul %collapsed_input, %collapsed_filter
+//   %expanded_ret     = tensor.expand_shape %ret
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                     Value transformedFilter, Value transformedInput) {
+  // The filter is collapsed first, then the input.
+  Value filterBatches = collaps2DData(rewriter, loc, transformedFilter);
+  Value inputBatches = collaps2DData(rewriter, loc, transformedInput);
+
+  auto inTy = cast<ShapedType>(transformedInput.getType());
+  auto filterTy = cast<ShapedType>(transformedFilter.getType());
+  Type elemTy = inTy.getElementType();
+  const int64_t tileH = inTy.getShape()[0];
+  const int64_t tileW = inTy.getShape()[1];
+  const int64_t batchN = inTy.getShape()[2];
+  const int64_t outF = filterTy.getShape()[3];
+
+  auto productTy =
+      RankedTensorType::get({tileH * tileW, batchN, outF}, elemTy);
+  Value productInit =
+      rewriter.create<tensor::EmptyOp>(loc, productTy.getShape(), elemTy);
+  auto product = rewriter.create<linalg::BatchMatmulOp>(
+      loc, productTy, ValueRange{inputBatches, filterBatches},
+      ValueRange{productInit});
+
+  // Split the batch dimension back into its two tile dimensions.
+  SmallVector<ReassociationIndices> groups = {{0, 1}, {2}, {3}};
+  auto resultTy = RankedTensorType::get({tileH, tileW, batchN, outF}, elemTy);
+  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy,
+                                                product.getResult(0), groups);
+}
+
+// This function transforms the output. The data layout of the output is HWNF.
+// The transformation matrix is 2-dimension. We need to extract H x W from
+// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %f = lo_f to hi_f step 1
+//     %extracted = extract input<h x w> from result<h x w x n x f>
+//     %ret = linalg.matmul AT, %extracted
+//     %ret = linalg.matmul %ret, A
+//     %inserted = insert %ret into ret<n x h x w x f>
+//
+// Before insertion, the product is multiplied element-wise by the scalar
+// factors of the applied AT/A matrices. Returns a null Value, without
+// creating any IR, in two cases:
+//  - the value is not one alpha-sized tile along each transformed dimension
+//    (and 1 along an untransformed one);
+//  - no AT/A matrix exists for (m, r).
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+                      Value output, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to AT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      ATMatrices = {
+          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+      };
+
+  // Map from (m, r) to A transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      AMatrices = {
+          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+      };
+
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto valueShape = valueType.getShape(); // H, W, N, F
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueN = valueShape[2];
+  int64_t valueF = valueShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  // Any other extent would produce matmuls with mismatched operand shapes.
+  if (valueH != alphaH || valueW != alphaW)
+    return Value();
+
+  // Resolve every transform matrix before emitting IR so that a failure
+  // never leaves partially built loops behind.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *ATMatrix = nullptr;
+  if (leftTransform) {
+    auto it = ATMatrices.find(key);
+    if (it == ATMatrices.end())
+      return Value();
+    ATMatrix = &it->second;
+  }
+  const TransformMatrix *AMatrix = nullptr;
+  if (rightTransform) {
+    auto it = AMatrices.find(key);
+    if (it == AMatrices.end())
+      return Value();
+    AMatrix = &it->second;
+  }
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value FIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (H, W, N, F)
+  auto extractValue =
+      extract2DData(rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/2,
+                    /*inLoopIdx=*/3, /*heightIdx=*/0, /*widthIdx=*/1);
+
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  int64_t leftScalarFactor = 1;
+  int64_t rightScalarFactor = 1;
+  Value matmulRetValue = extractValue;
+  if (leftTransform) {
+    leftScalarFactor = ATMatrix->scalarFactor;
+    retRows = ATMatrix->rows;
+    auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value AT = create2DTransformMatrix(rewriter, loc, *ATMatrix, elementType);
+    // Multiply AT x m
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    rightScalarFactor = AMatrix->scalarFactor;
+    retCols = AMatrix->cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value A = create2DTransformMatrix(rewriter, loc, *AMatrix, elementType);
+    // Multiply y = (AT x m) x A
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Multiply scalar factor.
+  Value scalarFactor = rewriter.create<arith::ConstantOp>(
+      loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
+
+  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+  SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
+                                       identityAffineMap, identityAffineMap};
+  auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
+      loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
+      ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value scalarVal = args[0];
+        Value matrixVal = args[1];
+        Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
+                                                           matrixVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
+
+  // Insert slice y
+  // Insert (H, W) to (N, H, W, F)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  Value combinedVal =
+      insert2DData(rewriter, loc, scalarMatrixOp.getResult(0), iterArg, NIter,
+                   FIter, retRows, retCols,
+                   /*outLoopIdx=*/0,
+                   /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// Replaces `convOp` with Winograd filter/input transform ops, a batched
+// matmul and a Winograd output transform, using the F(m, r) configuration.
+// Every precondition is checked before any IR is created, so on failure the
+// IR is left untouched.
+FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
+                                            linalg::Conv2DNhwcFhwcOp convOp,
+                                            int64_t m, int64_t r) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  // The tile arithmetic below needs static shapes. A dynamic dimension is
+  // encoded as a negative sentinel and would corrupt every size computed here.
+  if (!inputType.hasStaticShape() || !filterType.hasStaticShape() ||
+      !outputType.hasStaticShape())
+    return failure();
+
+  // Winograd minimal filtering computes a unit-stride, undilated convolution.
+  // Rewriting a strided or dilated convolution would silently change results.
+  for (int64_t stride : convOp.getStrides().getValues<int64_t>())
+    if (stride != 1)
+      return failure();
+  for (int64_t dilation : convOp.getDilations().getValues<int64_t>())
+    if (dilation != 1)
+      return failure();
+
+  auto outputShape = outputType.getShape(); // N, H, W, F
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputC = inputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
+  if ((outputH != outputW) && (outputH != 1 && outputW != 1))
+    return failure();
+  if ((filterH != filterW) && (filterH != 1 && filterW != 1))
+    return failure();
+
+  if ((outputH == 1 && filterH != 1) || (outputH != 1 && filterH == 1))
+    return failure();
+  if ((outputW == 1 && filterW != 1) || (outputW != 1 && filterW == 1))
+    return failure();
+
+  // The (m, r) configurations for which transform matrices exist.
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (it == validConfigs.end())
+    return failure();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = outputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = outputW != 1;
+
+  // Every transformed dimension must use an r-wide kernel; the transform
+  // matrices are only valid for that kernel size.
+  if ((leftTransform && filterH != r) || (rightTransform && filterW != r))
+    return failure();
+
+  // The transform ops cover the output with whole m-sized tiles and do not
+  // pad, so a partial tile would be silently dropped by `outputH / m` below.
+  // m is non-zero here because (m, r) is one of the valid configurations.
+  if ((leftTransform && outputH % m != 0) ||
+      (rightTransform && outputW % m != 0))
+    return failure();
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // Create operator for filter transform
+  Type elementType = filterType.getElementType();
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+  int64_t retHeight = leftTransform ? (outputH / m) * alphaH : 1;
+  int64_t retWidth = rightTransform ? (outputW / m) * alphaW : 1;
+  auto retType = RankedTensorType::get({retHeight, retWidth, filterC, filterF},
+                                       elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, outputH, outputW, m, r);
+
+  // Create operator for input transform
+  retType =
+      RankedTensorType::get({retHeight, retWidth, inputN, inputC}, elementType);
+  retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+      loc, retType, input, retValue, outputH, outputW, m, r);
+
+  Value matmulRet =
+      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
+
+  // create operator for output transform
+  auto transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+      loc, outputType, matmulRet, output, m, r);
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getOperation();
+}
+
+// Lowers a linalg.winograd_filter_transform op into explicit loops via
+// filterTransform and replaces the op with the result. Fails when
+// filterTransform returns a null Value.
+FailureOr<Operation *>
+winogradRewriteFilterTransformHelper(RewriterBase &rewriter,
+                                     linalg::WinogradFilterTransformOp op) {
+  Value filter = op.getFilter();
+  auto filterTy = cast<ShapedType>(filter.getType());
+
+  // The filter is laid out (F, H, W, C). A unit H selects F(1 x m, 1 x r) and
+  // a unit W selects F(m x 1, r x 1); only the other side is transformed.
+  const bool transformRows = filterTy.getShape()[1] != 1;
+  const bool transformCols = filterTy.getShape()[2] != 1;
+
+  Value lowered =
+      filterTransform(rewriter, op.getLoc(), filter, op.getOutput(), op.getM(),
+                      op.getR(), transformRows, transformCols);
+  if (!lowered)
+    return failure();
+
+  rewriter.replaceOp(op, lowered);
+  return lowered.getDefiningOp();
+}
+
+// Lowers a linalg.winograd_input_transform op into explicit loops via
+// inputTransform and replaces the op with the result. Fails when
+// inputTransform returns a null Value.
+FailureOr<Operation *>
+winogradRewriteInputTransformHelper(RewriterBase &rewriter,
+                                    linalg::WinogradInputTransformOp op) {
+  Value input = op.getInput();
+  auto inputTy = cast<ShapedType>(input.getType());
+
+  // The input is laid out (N, H, W, C). A unit H selects F(1 x m, 1 x r) and
+  // a unit W selects F(m x 1, r x 1); only the other side is transformed.
+  const bool transformRows = inputTy.getShape()[1] != 1;
+  const bool transformCols = inputTy.getShape()[2] != 1;
+
+  Value lowered =
+      inputTransform(rewriter, op.getLoc(), input, op.getOutput(), op.getM(),
+                     op.getR(), transformRows, transformCols);
+  if (!lowered)
+    return failure();
+
+  rewriter.replaceOp(op, lowered);
+  return lowered.getDefiningOp();
+}
+
+// Lowers a linalg.winograd_output_transform op into explicit loops via
+// outputTransform and replaces the op with the result. Fails when
+// outputTransform returns a null Value.
+FailureOr<Operation *>
+winogradRewriteOutputTransformHelper(RewriterBase &rewriter,
+                                     linalg::WinogradOutputTransformOp op) {
+  Value value = op.getValue();
+  auto valueTy = cast<ShapedType>(value.getType());
+
+  // The value is laid out (H, W, N, F). A unit H selects F(1 x m, 1 x r) and
+  // a unit W selects F(m x 1, r x 1); only the other side is transformed.
+  const bool transformRows = valueTy.getShape()[0] != 1;
+  const bool transformCols = valueTy.getShape()[1] != 1;
+
+  Value lowered =
+      outputTransform(rewriter, op.getLoc(), value, op.getOutput(), op.getM(),
+                      op.getR(), transformRows, transformCols);
+  if (!lowered)
+    return failure();
+
+  rewriter.replaceOp(op, lowered);
+  return lowered.getDefiningOp();
+}
+
+/// Pattern lowering linalg.winograd_filter_transform to loops of matmuls.
+class WinogradConv2DRewriteFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> lowered =
+        winogradRewriteFilterTransformHelper(rewriter, op);
+    return success(succeeded(lowered));
+  }
+};
+
+/// Pattern lowering linalg.winograd_input_transform to loops of matmuls.
+class WinogradConv2DRewriteInputTransform final
+    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> lowered =
+        winogradRewriteInputTransformHelper(rewriter, op);
+    return success(succeeded(lowered));
+  }
+};
+
+/// Pattern lowering linalg.winograd_output_transform to loops of matmuls
+/// followed by a scalar rescale.
+class WinogradConv2DRewriteOutputTransform final
+    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> lowered =
+        winogradRewriteOutputTransformHelper(rewriter, op);
+    return success(succeeded(lowered));
+  }
+};
+
+/// Rewrites linalg.conv_2d_nhwc_fhwc into the Winograd F(m, r) form.
+///
+/// The pattern must be built with an explicit (m, r). The inherited
+/// OpRewritePattern(context) constructor is intentionally not exposed,
+/// because it would leave `m` and `r` uninitialized.
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    Value filter = convOp.getInputs()[1];
+    auto filterType = cast<ShapedType>(filter.getType());
+    auto filterShape = filterType.getShape(); // F, H, W, C
+    int64_t filterH = filterShape[1];
+    int64_t filterW = filterShape[2];
+    Value output = convOp.getOutputs()[0];
+    auto outputType = cast<ShapedType>(output.getType());
+    auto outputShape = outputType.getShape(); // N, H, W, F
+    int64_t outputH = outputShape[1];
+    int64_t outputW = outputShape[2];
+
+    // Each filter dimension must be either the kernel size r or 1. Reject
+    // the op if *either* dimension violates this, not only when both do.
+    if ((filterH != r && filterH != 1) || (filterW != r && filterW != 1))
+      return failure();
+
+    // Each output dimension must hold at least one m-sized tile, or be 1.
+    if ((outputH < m && outputH != 1) || (outputW < m && outputW != 1))
+      return failure();
+
+    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+      return failure();
+
+    return success();
+  }
+
+private:
+  /// Output tile size of the F(m, r) configuration.
+  int64_t m;
+  /// Filter size of the F(m, r) configuration.
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public entry points. Each one forwards to the corresponding helper in the
+// anonymous namespace above and returns its result unchanged.
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `op` with the Winograd F(m, r) algorithm: filter transform,
+/// input transform, batch_matmul, then output transform. Returns the new
+/// output-transform op, or failure if `op` does not fit the configuration.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r);
+}
+
+/// Lowers a winograd_filter_transform op into scf.for loops of linalg.matmul
+/// ops and replaces it. Fails if the filter shape or (m, r) is unsupported.
+FailureOr<Operation *>
+winogradConv2DRewriteFilterTransform(RewriterBase &rewriter,
+                                     linalg::WinogradFilterTransformOp op) {
+  return winogradRewriteFilterTransformHelper(rewriter, op);
+}
+
+/// Lowers a winograd_input_transform op into scf.for loops of linalg.matmul
+/// ops and replaces it. Fails if the input shape or (m, r) is unsupported.
+FailureOr<Operation *>
+winogradConv2DRewriteInputTransform(RewriterBase &rewriter,
+                                    linalg::WinogradInputTransformOp op) {
+  return winogradRewriteInputTransformHelper(rewriter, op);
+}
+
+/// Lowers a winograd_output_transform op into scf.for loops of linalg.matmul
+/// ops plus a scalar rescale, and replaces it. Fails if the value shape or
+/// (m, r) is unsupported.
+FailureOr<Operation *>
+winogradConv2DRewriteOutputTransform(RewriterBase &rewriter,
+                                     linalg::WinogradOutputTransformOp op) {
+  return winogradRewriteOutputTransformHelper(rewriter, op);
+}
+
+/// Adds the pattern that rewrites linalg.conv_2d_nhwc_fhwc into the Winograd
+/// F(m, r) form.
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  patterns.add<WinogradConv2DNhwcFhwc>(patterns.getContext(), m, r);
+}
+
+/// Adds one lowering pattern per Winograd transform op, in the order filter,
+/// input, output.
+void populateWinogradConv2DRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<WinogradConv2DRewriteFilterTransform,
+               WinogradConv2DRewriteInputTransform,
+               WinogradConv2DRewriteOutputTransform>(patterns.getContext());
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 0000000000000..b7f9429625958
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,166 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+// Input: already Winograd-transformed F(4, 3) IR for a 3x3 conv with an 8x8
+// output (the same shape of IR that transform.structured.winograd_conv2d emits
+// in transform-winograd-conv2d.mlir). It consists of filter/input transforms,
+// a batch_matmul over the collapsed 12x12 leading dims, and an output
+// transform whose outs tensor is %arg2 broadcast to the output shape.
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = tensor.empty() : tensor<12x12x5x2xf32>
+  %3 = linalg.winograd_filter_transform output_height(8) output_width(8) m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<12x12x5x2xf32>) -> tensor<12x12x5x2xf32>
+  %4 = tensor.empty() : tensor<12x12x2x5xf32>
+  %5 = linalg.winograd_input_transform output_height(8) output_width(8) m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%4 : tensor<12x12x2x5xf32>) -> tensor<12x12x2x5xf32>
+  %collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<12x12x5x2xf32> into tensor<144x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2], [3]] : tensor<12x12x2x5xf32> into tensor<144x2x5xf32>
+  %6 = tensor.empty() : tensor<144x2x2xf32>
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%6 : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1], [2], [3]] output_shape [12, 12, 2, 2] : tensor<144x2x2xf32> into tensor<12x12x2x2xf32>
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<12x12x2x2xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %8 : tensor<2x8x8x2xf32>
+}
+
+// Tile each Winograd transform op with tile size 4 on two of its loops. The
+// expected IR shows these as loops over the 8x8 output height/width. Then
+// decompose each tiled op with transform.structured.winograd_conv2d_rewrite.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    // Step 1: tile the filter, input and output transform ops.
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [4, 4, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [4, 4, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 4, 4, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    // Step 2: decompose each tiled op into loops of linalg.matmul.
+    %6 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %1 : (!transform.any_op) -> !transform.any_op
+    %7 = transform.structured.winograd_conv2d_rewrite %6 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.winograd_conv2d_rewrite %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.winograd_conv2d_rewrite %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// Expected IR: each tiled Winograd transform op becomes an scf.for nest over
+// the two tiled dims (0..8 step 4). Its decomposition is an inner scf.for nest
+// over the remaining two dims that handles each 2-D slice with two
+// linalg.matmul ops. SSA values are always referenced through captured
+// FileCheck variables, never by their printed names.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ((d0 floordiv 4) * 6)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG:    %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:    %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:    %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:    %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:    %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:    %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:    %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:    %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<12x12x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = tensor.empty() : tensor<12x12x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<12x12x5x2xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<12x12x5x2xf32>) {
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S2]][%[[S13]], %[[S14]], 0, 0] [6, 6, 5, 2] [1, 1, 1, 1] : tensor<12x12x5x2xf32> to tensor<6x6x5x2xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:         %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_7]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_8]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S19]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[INSERTED_SLICE_9]] into %[[ARG10]][0, 0, %[[ARG9]], %[[ARG7]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_10]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S18]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[S16:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S17:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[S16]], %[[S17]], 0, 0] [6, 6, 5, 2] [1, 1, 1, 1] : tensor<6x6x5x2xf32> into tensor<12x12x5x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<12x12x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<12x12x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[S5:.*]] = tensor.empty() : tensor<12x12x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<12x12x2x5xf32>
+// CHECK-NEXT:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<12x12x2x5xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<12x12x2x5xf32>) {
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[ARG3]], %[[ARG5]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][%[[S13]], %[[S14]], 0, 0] [6, 6, 2, 5] [1, 1, 1, 1] : tensor<12x12x2x5xf32> to tensor<6x6x2x5xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:         %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x2x5xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S18]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[S16:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S17:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[S16]], %[[S17]], 0, 0] [6, 6, 2, 5] [1, 1, 1, 1] : tensor<6x6x2x5xf32> into tensor<12x12x2x5xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<12x12x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<12x12x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S7]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:   %[[S8:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S8]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2], [3]] output_shape [12, 12, 2, 2] : tensor<144x2x2xf32> into tensor<12x12x2x2xf32>
+// NOTE(review): the decomposed output transform below inserts the scaled
+// matmul result over the slice of the outs tensor instead of adding to it.
+// As a result, the broadcast bias that initializes the conv output appears to
+// be dropped, while linalg.conv_2d_nhwc_fhwc accumulates into its init.
+// Confirm this is the intended semantics.
+// CHECK-NEXT:   %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S11:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][%[[S13]], %[[S14]], 0, 0] [6, 6, 2, 2] [1, 1, 1, 1] : tensor<12x12x2x2xf32> to tensor<6x6x2x2xf32>
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S1]][0, %[[ARG3]], %[[ARG5]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:         %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x2x2xf32> to tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:           %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]], #[[$MAP4]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S20]] : f32, tensor<4x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) {
+// CHECK-NEXT:           ^bb0(%[[IN:.*]]: f32, %[[IN_12:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:             %[[S24:.*]] = arith.mulf %[[IN]], %[[IN_12]] : f32
+// CHECK-NEXT:             linalg.yield %[[S24]] : f32
+// CHECK-NEXT:           } -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_11]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S16]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[ARG3]], %[[ARG5]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S11]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..231f87459f230
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+// A 3x3 linalg.conv_2d_nhwc_fhwc with unit strides and dilations (10x10 input,
+// 8x8 output). Its init tensor is the single-element %arg2 broadcast to the
+// output shape.
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+// Rewrite every matched linalg.conv_2d_nhwc_fhwc with Winograd F(m = 4, r = 3).
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// Expected IR: the conv is replaced by filter and input transforms, a
+// linalg.batch_matmul over their collapsed 12x12 leading dims, and an output
+// transform whose outs tensor is the broadcast bias.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<12x12x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(8) output_width(8) m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<12x12x5x2xf32>) -> tensor<12x12x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<12x12x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(8) output_width(8) m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<12x12x2x5xf32>) -> tensor<12x12x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [12, 12, 2, 2] : tensor<144x2x2xf32> into tensor<12x12x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<12x12x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
new file mode 100644
index 0000000000000..691be1fe2b288
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d-rewrite | FileCheck %s
+
+// Winograd F(4, 3) IR for a 4x4 output, i.e. a single 6x6 transformed tile.
+// The test pass decomposes the three winograd_*_transform ops without tiling.
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = tensor.empty() : tensor<6x6x5x2xf32>
+  %3 = linalg.winograd_filter_transform output_height(4) output_width(4) m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %4 = tensor.empty() : tensor<6x6x2x5xf32>
+  %5 = linalg.winograd_input_transform output_height(4) output_width(4) m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%4 : tensor<6x6x2x5xf32>) -> tensor<6x6x2x5xf32>
+  %collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+  %6 = tensor.empty() : tensor<36x2x2xf32>
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %8 : tensor<2x4x4x2xf32>
+}
+
+// Expected IR: each transform op becomes a 2-deep scf.for nest over its two
+// non-spatial dims: F/C for the filter, N/C for the input, N/F for the output.
+// Each iteration transforms one 2-D slice with two linalg.matmul ops against
+// constant matrices. The output transform also scales the result by a
+// constant.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:   %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S4]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x2x5xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// NOTE(review): the output-transform loop below starts from the broadcast bias
+// tensor. However, each iteration inserts the scaled matmul result (computed
+// into a fresh tensor.empty) over the corresponding slice instead of adding to
+// it, so the bias appears to be discarded. linalg.conv_2d_nhwc_fhwc
+// accumulates into its init tensor; confirm the intended semantics.
+// CHECK-NEXT:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x2x2xf32> to tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:        %[[S17:.*]] = arith.mulf %[[IN]], %[[IN_9]] : f32
+// CHECK-NEXT:        linalg.yield %[[S17]] : f32
+// CHECK-NEXT:      } -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S16:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:}
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..0f827f16b6b81
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,169 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+// 4x4 output, 3x3 filter: expect m(4) r(3), 6x6 transforms, and a 36-batch matmul.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform output_height(4) output_width(4) m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform output_height(4) output_width(4) m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x2x5xf32>) -> tensor<6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_2x2_3x3(%arg0: tensor<2x4x4x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x2x2x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x4x4x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %2 : tensor<2x2x2x2xf32>
+}
+// 2x2 output, 3x3 filter: expect m(2) r(3), 4x4 transforms, and a 16-batch matmul.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_2x2_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x4x4x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<4x4x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(2) output_width(2) m(2) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<4x4x5x2xf32>) -> tensor<4x4x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<4x4x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(2) output_width(2) m(2) r(3) ins(%[[ARG0]] : tensor<2x4x4x5xf32>) outs(%[[S4]] : tensor<4x4x2x5xf32>) -> tensor<4x4x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<4x4x5x2xf32> into tensor<16x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<4x4x2x5xf32> into tensor<16x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<16x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<16x2x5xf32>, tensor<16x5x2xf32>) outs(%[[S6]] : tensor<16x2x2xf32>) -> tensor<16x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [4, 4, 2, 2] : tensor<16x2x2xf32> into tensor<4x4x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(3) ins(%[[EXPANDED]] : tensor<4x4x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x2x2x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %2 : tensor<2x2x2x2xf32>
+}
+// 2x2 output, 5x5 filter: expect m(2) r(5), 6x6 transforms, and a 36-batch matmul.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_2x2_5x5
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(2) output_width(2) m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(2) output_width(2) m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x2x5xf32>) -> tensor<6x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x1x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x1x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+  return %2 : tensor<2x1x4x2xf32>
+}
+// 1x4 output, 1x3 filter: expect m(4) r(3), 1x6 transforms (W only), 6-batch matmul.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_1x4_1x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(1) output_width(4) m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(1) output_width(4) m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x2x5xf32>) -> tensor<1x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x1x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x1x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %2 : tensor<2x4x1x2xf32>
+}
+// 4x1 output, 3x1 filter: expect m(4) r(3), 6x1 transforms (H only), 6-batch matmul.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_4x1_3x1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(4) output_width(1) m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(4) output_width(1) m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x2x5xf32>) -> tensor<6x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..d625146a669b4 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,14 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+      llvm::cl::init(false)};
+  Option<bool> testWinogradConv2DRewrite{
+      *this, "test-winograd-conv2d-rewrite",
+      llvm::cl::desc("Test rewrite Winograd conv2d ops"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +215,20 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  const std::pair<int64_t, int64_t> kMR[] = {{4, 3}, {2, 3}, {2, 5}}; // (m, r)
+  RewritePatternSet winogradPatterns(funcOp.getContext());
+  for (auto [m, r] : kMR) // Patterns are registered in table order.
+    populateWinogradConv2DPatterns(winogradPatterns, m, r);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(winogradPatterns));
+}
+
+static void applyWinogradConv2DRewrite(func::FuncOp funcOp) {
+  RewritePatternSet rewritePatterns(funcOp.getContext()); // Winograd rewrites.
+  populateWinogradConv2DRewritePatterns(rewritePatterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(rewritePatterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +253,10 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
+  if (testWinogradConv2DRewrite)
+    return applyWinogradConv2DRewrite(getOperation());
 }
 
 namespace mlir {



More information about the Mlir-commits mailing list