[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)

Hsiangkai Wang llvmlistbot at llvm.org
Wed Jul 17 07:12:07 PDT 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96183

>From 4240341b4f06f1b77f63b0f619cae3804d88eb68 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:24:07 +0100
Subject: [PATCH 01/22] [mlir][linalg] Implement Conv2D using Winograd Conv2D
 algorithm

Define high-level Winograd operators and convert conv_2d_nhwc_fhwc into
Winograd operators. The Winograd Conv2D algorithm needs three transform
operators: one each for the input, filter, and output transformations.

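The formula of the Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

where x denotes matrix multiplication and @ denotes element-wise
multiplication. The three transforms are:

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A, where y = (G x g x G^T) @ (B^T x d x B)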

The implementation is based on the paper "Fast Algorithms for
Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       | 114 +++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  78 +++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 321 ++++++++++++++++++
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 248 ++++++++++++++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  13 +
 7 files changed, 779 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..de1097b6ac27b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, the filter and the
+    input are converted into a format suitable for batched matrix multiply.
+    After the matrix multiply, the result is converted into the final output
+    tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    where `x` denotes matrix multiplication and `@` denotes element-wise
+    multiplication (accumulated over channels). The size of output Y is
+    m x m. The size of filter g is r x r. The size of input d is
+    (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are transformation
+    matrices.
+
+    This operator represents the high-level concept of the filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm. The
+    `filter` operand must be a rank-4 tensor and the `output` operand a
+    rank-6 tensor with the same element type as `filter`.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$filter,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, the filter and the
+    input are converted into a format suitable for batched matrix multiply.
+    After the matrix multiply, the result is converted into the final output
+    tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    where `x` denotes matrix multiplication and `@` denotes element-wise
+    multiplication (accumulated over channels). The size of output Y is
+    m x m. The size of filter g is r x r. The size of input d is
+    (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are transformation
+    matrices.
+
+    This operator represents the high-level concept of the input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm. The
+    `input` operand must be a rank-4 tensor and the `output` operand a
+    rank-6 tensor with the same element type as `input`.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, the filter and the
+    input are converted into a format suitable for batched matrix multiply.
+    After the matrix multiply, the result is converted into the final output
+    tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    where `x` denotes matrix multiplication and `@` denotes element-wise
+    multiplication (accumulated over channels). The size of output Y is
+    m x m. The size of filter g is r x r. The size of input d is
+    (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are transformation
+    matrices.
+
+    This operator represents the high-level concept of the output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm, where
+    y = (G x g x G^T) @ (B^T x d x B). The `value` operand must be a rank-6
+    tensor and the `output` operand a rank-4 tensor with the same element
+    type as `value`.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$value,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+/// The patterns rewrite `linalg.conv_2d_nhwc_fhwc` into
+/// `linalg.winograd_filter_transform`, `linalg.winograd_input_transform`,
+/// a `linalg.batch_matmul` and `linalg.winograd_output_transform`.
+/// Only convolutions with static input and filter shapes and unit strides and
+/// dilations are rewritten, and only (m, r) = (2, 3), (4, 3) and (2, 5) are
+/// currently supported; other convolutions are left unchanged.
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..7bf2a5bca037f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the filter and output operands share an element type, that
+/// the filter is rank 4 and that the output is rank 6. Diagnostics name the
+/// operand as `filter` so they match the op's assembly and accessors.
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto filterElemType = filterType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (filterElemType != outputElemType) {
+    return emitOpError() << "expected element type of filter " << filterElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  if (filterType.getRank() != 4)
+    return emitOpError() << "expected rank of filter is 4";
+
+  if (outputType.getRank() != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that `input` and `output` agree on element type, that `input` is a
+/// rank-4 tensor and that `output` is a rank-6 tensor.
+LogicalResult WinogradInputTransformOp::verify() {
+  ShapedType inType = cast<ShapedType>(getInput().getType());
+  ShapedType outType = cast<ShapedType>(getOutput().getType());
+
+  Type inElem = inType.getElementType();
+  Type outElem = outType.getElementType();
+  if (inElem != outElem)
+    return emitOpError() << "expected element type of input " << inElem
+                         << " to match element type of output " << outElem;
+
+  if (inType.getRank() != 4)
+    return emitOpError() << "expected rank of input is 4";
+  if (outType.getRank() != 6)
+    return emitOpError() << "expected rank of output is 6";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the value and output operands share an element type, that
+/// the value is rank 6 and that the output is rank 4.
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto valueElemType = valueType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (valueElemType != outputElemType) {
+    return emitOpError() << "expected element type of value " << valueElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  // The rank-6 operand of this op is `value`; name it as such so the
+  // diagnostic agrees with the element-type message above.
+  if (valueType.getRank() != 6)
+    return emitOpError() << "expected rank of value is 6";
+
+  if (outputType.getRank() != 4)
+    return emitOpError() << "expected rank of output is 4";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..86e834d51f2fc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,321 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+// Key identifying a minimal filtering algorithm F(m, r). Both members are
+// int64_t so that the (m, r) values handed to the pattern (which are int64_t)
+// are compared without being truncated to int.
+using TransformMapKeyTy = std::pair<int64_t, int64_t>;
+
+// We use F(m, r) to define the size of minimal filtering algorithms.
+// m is the output dimension and r is the filter dimension. We can get
+// the input dimension, alpha, from the formula, alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+// Collapses a rank-6 tensor [d0, d1, d2, d3, d4, d5] into the rank-3 tensor
+// [d0 * d1 * d2 * d3, d4, d5] by merging the four leading dimensions.
+Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+  auto shapedType = cast<ShapedType>(data.getType());
+  ArrayRef<int64_t> dims = shapedType.getShape();
+  int64_t batch = dims[0] * dims[1] * dims[2] * dims[3];
+  RankedTensorType collapsedType = RankedTensorType::get(
+      {batch, dims[4], dims[5]}, shapedType.getElementType());
+  SmallVector<ReassociationIndices> groups{{0, 1, 2, 3}, {4}, {5}};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, data,
+                                                  groups);
+}
+
+// Multiplies the transformed input by the transformed filter with a single
+// linalg.batch_matmul. batch_matmul only accepts rank-3 operands, so the four
+// leading dimensions (tileH, tileW, alphaH, alphaW) of each rank-6 operand
+// are flattened into one batch dimension B:
+//
+//   input  [tileH, tileW, alphaH, alphaW, N, C] -> [B, N, C]
+//   filter [tileH, tileW, alphaH, alphaW, C, F] -> [B, C, F]
+//
+// The matmul reduces over the channel dimension C and produces [B, N, F],
+// which is expanded back to (tileH, tileW, alphaH, alphaW, N, F).
+//
+// Generated IR:
+//
+// %collapsed_filter = tensor.collapse_shape %filter
+// %collapsed_input = tensor.collapse_shape %input
+// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+// %expanded_ret = tensor.expand_shape %ret
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                     Value transformedFilter, Value transformedInput) {
+  // The filter is collapsed first, then the input (op creation order).
+  Value filterBatches = collapse2DData(rewriter, loc, transformedFilter);
+  Value inputBatches = collapse2DData(rewriter, loc, transformedInput);
+
+  ArrayRef<int64_t> fShape =
+      cast<ShapedType>(transformedFilter.getType()).getShape();
+  auto inType = cast<ShapedType>(transformedInput.getType());
+  ArrayRef<int64_t> iShape = inType.getShape();
+  Type elemType = inType.getElementType();
+
+  int64_t numBatches = iShape[0] * iShape[1] * iShape[2] * iShape[3];
+  auto batchedType =
+      RankedTensorType::get({numBatches, iShape[4], fShape[5]}, elemType);
+  Value acc =
+      rewriter.create<tensor::EmptyOp>(loc, batchedType.getShape(), elemType);
+
+  auto bmm = rewriter.create<linalg::BatchMatmulOp>(
+      loc, batchedType, ValueRange({inputBatches, filterBatches}),
+      ValueRange{acc});
+
+  // Restore the four leading tile/alpha dimensions on the matmul result.
+  auto expandedType = RankedTensorType::get(
+      {iShape[0], iShape[1], iShape[2], iShape[3], iShape[4], fShape[5]},
+      elemType);
+  SmallVector<ReassociationIndices> groups = {{0, 1, 2, 3}, {4}, {5}};
+  return rewriter.create<tensor::ExpandShapeOp>(loc, expandedType,
+                                                bmm.getResult(0), groups);
+}
+
+// Returns a tensor of type `alignedType` whose leading region holds `value`
+// (inserted at offset 0 with unit strides in all four dimensions) and whose
+// remaining elements are zero. `value` is expected to be a rank-4 tensor no
+// larger than `alignedType` in any dimension and of the same element type.
+//
+// The padding must be initialized: the Winograd input transform reads whole
+// (m + r - 1) x (m + r - 1) tiles, so padded elements flow through the
+// transform arithmetic. The contents of a bare tensor.empty are unspecified
+// and could inject arbitrary values (including NaN/Inf) into valid outputs.
+// NOTE: requires an element type for which `getZeroAttr` yields an attribute
+// (integer, index or floating point).
+Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
+                            RankedTensorType alignedType) {
+  Type elementType = alignedType.getElementType();
+  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+      loc, alignedType.getShape(), elementType);
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, elementType, rewriter.getZeroAttr(elementType));
+  Value alignedInput =
+      rewriter
+          .create<linalg::FillOp>(loc, ValueRange{zero},
+                                  ValueRange{emptyTensor})
+          .getResult(0);
+
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto valueShape = cast<ShapedType>(value.getType()).getShape();
+  SmallVector<OpFoldResult, 4> sizes;
+  for (int i = 0; i < 4; ++i)
+    sizes.emplace_back(rewriter.getIndexAttr(valueShape[i]));
+
+  return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
+                                                offsets, sizes, strides);
+}
+
+// Extracts the leading region of `value` described by `extractedType`: a
+// slice starting at offset 0 in each of the four dimensions, with unit
+// strides and sizes taken from the static shape of `extractedType`.
+Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+                               Value value, RankedTensorType extractedType) {
+  ArrayRef<int64_t> extents = extractedType.getShape();
+  SmallVector<OpFoldResult, 4> sizes;
+  for (int i = 0; i < 4; ++i)
+    sizes.emplace_back(rewriter.getIndexAttr(extents[i]));
+
+  SmallVector<OpFoldResult, 4> offsets(4, rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult, 4> strides(4, rewriter.getIndexAttr(1));
+  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+                                                 offsets, sizes, strides);
+}
+
+// Returns true when no element of `attr`, read as a signed integer, differs
+// from 1 (vacuously true for an empty attribute).
+bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::none_of(attr, [](const APInt &value) {
+    return value.getSExtValue() != 1;
+  });
+}
+
+// Rewrites `convOp` (a linalg.conv_2d_nhwc_fhwc) with the Winograd
+// F(m x m, r x r) algorithm, or its F(m x 1, r x 1) / F(1 x m, 1 x r) variants
+// for r x 1 / 1 x r filters. The generated sequence is:
+//
+//   winograd_filter_transform, winograd_input_transform,
+//   batch_matmul (via matrixMultiply), winograd_output_transform
+//
+// The input and output are padded up to a whole number of output tiles when
+// needed, and the result is sliced back to the original output shape.
+//
+// A match failure is reported, and no IR is created, unless:
+//   * input, filter and output have static shapes and one element type;
+//   * strides and dilations are all ones;
+//   * the filter is r x r, r x 1 or 1 x r;
+//   * (m, r) is one of (2, 3), (4, 3), (2, 5);
+//   * the input is no larger than the tiled (aligned) input size.
+// On success `convOp` is replaced and the op defining the replacement value
+// is returned.
+FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
+                                            linalg::Conv2DNhwcFhwcOp convOp,
+                                            int64_t m, int64_t r) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  // Tile counts and padded shapes are computed from the output shape, so
+  // dynamic output dimensions cannot be handled.
+  if (!outputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the output");
+
+  // Every generated tensor uses a single element type, and the transform op
+  // verifiers require matching element types, so mixed-type convolutions
+  // would produce invalid IR.
+  Type elementType = filterType.getElementType();
+  if (inputType.getElementType() != elementType ||
+      outputType.getElementType() != elementType)
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected the same element type for "
+                                       "input, filter and output");
+
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  if (!hasAllOneValues(convOp.getStrides()))
+    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+  auto filterShape = filterType.getShape();
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  auto inputShape = inputType.getShape();
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  auto outputShape = outputType.getShape();
+  int64_t outputN = outputShape[0];
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t outputF = outputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
+  bool isSupportedFilter = (filterH == r && filterW == r) ||
+                           (filterH == r && filterW == 1) ||
+                           (filterH == 1 && filterW == r);
+  if (!isSupportedFilter)
+    return rewriter.notifyMatchFailure(
+        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5)
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (!llvm::is_contained(validConfigs, key))
+    return rewriter.notifyMatchFailure(
+        convOp, "unsupported Winograd configuration (m, r)");
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  int64_t heightM = leftTransform ? m : 1;
+  int64_t widthM = rightTransform ? m : 1;
+  int64_t heightR = leftTransform ? r : 1;
+  int64_t widthR = rightTransform ? r : 1;
+
+  int64_t alphaH = heightM + heightR - 1;
+  int64_t alphaW = widthM + widthR - 1;
+  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+
+  // Input size needed to cover all tiles. The input can only be padded up to
+  // this size, never shrunk, so a larger input (allowed by the convolution
+  // verifier) would yield an invalid insert_slice; reject it before creating
+  // any IR.
+  int64_t alignedInputH = tileH * heightM + (heightR - 1);
+  int64_t alignedInputW = tileW * widthM + (widthR - 1);
+  if (inputH > alignedInputH || inputW > alignedInputW)
+    return rewriter.notifyMatchFailure(
+        convOp, "expected input no larger than the tiled input size");
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // --- Create operator for filter transform ---
+  auto retType = RankedTensorType::get(
+      {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, m, r);
+
+  // --- Create operator for input transform ---
+
+  // When input size - (r - 1) is not aligned with output tile size, pad the
+  // input so that it covers whole tiles.
+  if (alignedInputH != inputH || alignedInputW != inputW) {
+    auto alignedInputType = RankedTensorType::get(
+        {inputN, alignedInputH, alignedInputW, inputC}, elementType);
+    input = insertToAlignedTensor(rewriter, loc, input, alignedInputType);
+  }
+
+  retType = RankedTensorType::get(
+      {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType);
+  retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+      loc, retType, input, retValue, m, r);
+
+  Value matmulRet =
+      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
+
+  // --- Create operator for output transform ---
+
+  // When output size is not aligned with output tile size, we need to pad the
+  // output buffer to insert the full tiles after tiling.
+  int64_t alignedOutputH = tileH * heightM;
+  int64_t alignedOutputW = tileW * widthM;
+  bool isOutputUnaligned =
+      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
+  if (isOutputUnaligned) {
+    auto alignedOutputType = RankedTensorType::get(
+        {outputN, alignedOutputH, alignedOutputW, outputF}, elementType);
+    output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType);
+    outputType = alignedOutputType;
+  }
+
+  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+      loc, outputType, matmulRet, output, m, r);
+
+  // When output size is not aligned with output tile size, extract the
+  // value from the padded buffer.
+  if (isOutputUnaligned) {
+    transformedOutput = extractFromAlignedTensor(
+        rewriter, loc, transformedOutput,
+        RankedTensorType::get({outputN, outputH, outputW, outputF},
+                              elementType));
+  }
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+// Rewrite pattern applying the Winograd algorithm to linalg.conv_2d_nhwc_fhwc
+// ops; the (m, r) configuration is fixed when the pattern is constructed.
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> rewritten =
+        winogradConv2DHelper(rewriter, convOp, m, r);
+    return success(succeeded(rewritten));
+  }
+
+private:
+  // Output tile size m of F(m x m, r x r).
+  int64_t m;
+  // Filter size r of F(m x m, r x r).
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  // Register the Winograd rewrite of linalg.conv_2d_nhwc_fhwc for F(m, r).
+  patterns.add<WinogradConv2DNhwcFhwc>(patterns.getContext(), m, r);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..6cca3c602d4c0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+// F(4x4, 3x3): 6x6 input and 3x3 filter give a single 4x4 output tile.
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+// F(2x2, 5x5): 6x6 input and 5x5 filter give a single 2x2 output tile.
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x2x2x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %2 : tensor<2x2x2x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_2x2_5x5
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+// F(1x4, 1x3): with a filter height of 1 only the W dimension is transformed.
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x1x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x1x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+  return %2 : tensor<2x1x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_1x4_1x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+// F(4x1, 3x1): with a filter width of 1 only the H dimension is transformed.
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x1x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x1x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %2 : tensor<2x4x1x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_4x1_3x1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: }
+
+// -----
+// 8x8 output is a multiple of m = 4: 2x2 tiles and no padding are needed.
+func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_aligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+// 9x9 output is not a multiple of m = 4, so it is padded to 12x12 (3x3 tiles).
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = tensor.empty() : tensor<2x9x9x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x9x9x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %2 : tensor<2x9x9x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x9x9x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:  %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+// CHECK-NEXT:  %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
+
+// -----
+// A 3x2 filter matches no registered (m, r) pair: the conv is left as is.
+func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_1
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+// A 4x4 filter matches no registered (m, r) pair: the conv is left as is.
+func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_2
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+// Dynamic input/output shapes are not supported: the conv is left as is.
+func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_3
+// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..12cb46a5968f1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  RewritePatternSet winogradPatterns(funcOp.getContext());
+  populateWinogradConv2DPatterns(winogradPatterns, 4, 3); // F(4x4, 3x3)
+  populateWinogradConv2DPatterns(winogradPatterns, 2, 5); // F(2x2, 5x5)
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(winogradPatterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {

>From 374b0d5b83ce080bea690199380e270a36ad1c52 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:49:08 +0100
Subject: [PATCH 02/22] [mlir][linalg] Add transform operator for Winograd
 Conv2D algorithm

Add a transform operation, structured.winograd_conv2d, that converts
linalg.conv_2d_nhwc_fhwc into Linalg Winograd operations.
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 ++
 .../TransformOps/LinalgTransformOps.cpp       | 25 ++++++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  6 ++
 .../Linalg/transform-winograd-conv2d.mlir     | 88 +++++++++++++++++++
 5 files changed, 177 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..68d0f713caad4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operation into a
+    batched matrix multiply. Before the matrix multiply, it converts the filter
+    and input into a format suitable for batched matrix multiply. After the
+    matrix multiply, it converts the output into the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    #### Return modes:
+
+    This operation produces a silenceable failure if `target` is unsupported.
+    Otherwise, the operation succeeds and returns a handle to the last
+    operation of the sequence that replaces the original convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,
+                       I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+  // NOTE(review): this builder omits the required `m`/`r` attributes; confirm.
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..da107b66257a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Rewrite linalg.conv_2d_nhwc_fhwc using Winograd F(m x m, r x r); m is the
+/// output tile size and r is the filter size. On success `op` is replaced and
+/// the last op of the replacement sequence is returned, otherwise failure.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..d051b29e1f06f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  Operation *targetOp = target.getOperation();
+  rewriter.setInsertionPoint(targetOp);
+  // Only linalg.conv_2d_nhwc_fhwc is rewritten; any other op is rejected.
+  FailureOr<Operation *> replacement = failure();
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(targetOp))
+    replacement = winogradConv2D(rewriter, convOp, getM(), getR());
+  else
+    replacement = rewriter.notifyMatchFailure(targetOp, "not supported");
+
+  if (failed(replacement))
+    return emitDefaultSilenceableFailure(target);
+
+  // Hand back the last op of the sequence that replaced the convolution.
+  results.push_back(*replacement);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 86e834d51f2fc..d1f4be8bbf29a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -311,6 +311,12 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
+FailureOr<Operation *>
+winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m,
+               int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r); // replaces `op` on success
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..1e74fea5a1c31
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+// Aligned case: the 8x8 output is a multiple of m = 4, so F(4 x 4, 3 x 3)
+// uses the input and output directly without padding (2x2 tiles of 6x6).
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+// Apply Winograd F(4 x 4, 3 x 3) to every linalg.conv_2d_nhwc_fhwc.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// Unaligned case: the 9x9 output is not a multiple of m = 4. The input is
+// padded to 14x14 and the output to 12x12 (3x3 tiles), and the final result
+// is sliced back to 9x9.
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = tensor.empty() : tensor<2x9x9x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x9x9x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %2 : tensor<2x9x9x2xf32>
+}
+
+// Apply Winograd F(4 x 4, 3 x 3) to every linalg.conv_2d_nhwc_fhwc.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x9x9x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:   %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
+// CHECK-NEXT:   %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+// CHECK-NEXT:   %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }

>From 24c4f957ae673c2955fc0674f91e488813d59350 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 17:39:49 +0100
Subject: [PATCH 03/22] [mlir][linalg] Decompose winograd operators

Convert the Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform operators into nested loops that perform matrix
multiplications with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel sizes (3 and 5) and different output sizes
(2 and 4). Besides symmetric kernel sizes 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper, Fast Algorithms for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
---
 .../Dialect/Linalg/Transforms/Transforms.h    |   3 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 773 ++++++++++++++++++
 .../Linalg/winograd-conv2d-rewrite.mlir       | 105 +++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  11 +
 4 files changed, 892 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index da107b66257a5..bb7ec590faad0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1703,6 +1703,9 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r);
 
+/// Populates `patterns` with rewrites that decompose the Winograd operators
+/// (linalg.winograd_filter_transform, linalg.winograd_input_transform, and
+/// linalg.winograd_output_transform) into nested loops of matrix
+/// multiplications with constant transform matrices.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d1f4be8bbf29a..d245723c85646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -12,7 +12,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -23,6 +26,156 @@ namespace linalg {
 
 namespace {
 
+// clang-format off
+// Winograd Conv2D computes its result with a minimal 2D filtering algorithm.
+// For F(m x m, r x r), where m is the output tile size and r is the filter
+// size, the formula is
+//
+// Y = A^T x [ (G x g x G^T) @ (B^T x d x B) ] x A
+//
+// where g is the filter, d is an input tile, `x` is matrix multiplication,
+// and `@` is element-wise (Hadamard) multiplication. The formula needs six
+// constant transformation matrices: G, G^T, B^T, B, A^T, and A. Transformed
+// tiles are alpha x alpha with alpha = m + r - 1.
+//
+// The following tables define these constant matrices for F(2 x 2, 3 x 3),
+// F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5). Each table is stored row by row, and
+// each X^T table is the transpose of the corresponding X table.
+
+// F(2 x 2, 3 x 3): G is 4x3, B^T/B are 4x4, A^T is 2x4.
+constexpr float G_2x2_3x3[] = {
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = {
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = {
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = {
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = {
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+// F(4 x 4, 3 x 3): G is 6x3, B^T/B are 6x6, A^T is 4x6. The A^T/A tables are
+// registered with a scalarFactor of 32 in outputTransform.
+constexpr float G_4x4_3x3[] = {
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+// F(2 x 2, 5 x 5): G is 6x5, B^T/B are 6x6, A^T is 2x6. The A^T/A tables are
+// registered with a scalarFactor of 16 in outputTransform.
+constexpr float G_2x2_5x5[] = {
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = {
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
 using TransformMapKeyTy = std::pair<int, int>;
 
 // We use F(m, r) to define the size of minimal filtering algorithms.
@@ -36,6 +189,92 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+/// A constant transform matrix backed by a static table of floats.
+/// The table holds `rows * cols` values stored row by row. It is not owned
+/// by this struct.
+struct TransformMatrix {
+  TransformMatrix(const float *data, int64_t numRows, int64_t numCols,
+                  int64_t scale = 1)
+      : table(data), rows(numRows), cols(numCols), scalarFactor(scale) {}
+
+  /// Pointer to the first element of the row-by-row table.
+  const float *table;
+  /// Number of rows in the matrix.
+  int64_t rows;
+  /// Number of columns in the matrix.
+  int64_t cols;
+  /// Scale associated with the table. It is 1 unless specified; some A/AT
+  /// tables are registered with 16 or 32.
+  int64_t scalarFactor;
+};
+
+/// Materializes `transform` as an arith.constant of type
+/// tensor<rows x cols x `type`>.
+///
+/// The tables are stored as `float`. Each value is converted to the
+/// semantics of `type` (round to nearest, ties to even), so the constant can
+/// be built for any floating-point element type (f16, bf16, f32, f64, ...),
+/// not only f32. For f32 the values are unchanged.
+///
+/// `type` must be a floating-point type.
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+                              TransformMatrix transform, Type type) {
+  auto floatType = cast<FloatType>(type);
+  const llvm::fltSemantics &semantics = floatType.getFloatSemantics();
+
+  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+  SmallVector<APFloat> values;
+  values.reserve(constVec.size());
+  for (float element : constVec) {
+    APFloat value(element);
+    bool losesInfo = false;
+    // Narrowing (e.g. to f16) may round; that is expected for constants
+    // such as 1/3.
+    (void)value.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
+    values.push_back(value);
+  }
+
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols},
+                   floatType),
+               values));
+}
+
+/// Extracts a 2-D (height x width) slice from the rank-`srcSize` `source`.
+///
+/// The slice is positioned at `outLoopIndex` along dimension `outLoopIdx` and
+/// at `inLoopIndex` along dimension `inLoopIdx`; all other offsets are 0.
+/// Dimensions `heightIdx` and `widthIdx` are taken whole (their static
+/// extents in `source`). Every other dimension has size 1, and strides are 1.
+/// The unit dimensions are then dropped, so the returned value is a rank-2
+/// tensor of shape [height, width].
+Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto srcType = cast<ShapedType>(source.getType());
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  const int64_t h = srcShape[heightIdx];
+  const int64_t w = srcShape[widthIdx];
+
+  // Full-rank slice shape: unit everywhere except the two spatial dims.
+  SmallVector<int64_t> sliceShape(srcSize, 1);
+  sliceShape[heightIdx] = h;
+  sliceShape[widthIdx] = w;
+
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+  OpFoldResult one = rewriter.getIndexAttr(1);
+
+  SmallVector<OpFoldResult, 6> sliceOffsets(srcSize, zero);
+  sliceOffsets[outLoopIdx] = outLoopIndex;
+  sliceOffsets[inLoopIdx] = inLoopIndex;
+
+  SmallVector<OpFoldResult, 6> sliceSizes(srcSize, one);
+  sliceSizes[heightIdx] = rewriter.getIndexAttr(h);
+  sliceSizes[widthIdx] = rewriter.getIndexAttr(w);
+
+  SmallVector<OpFoldResult, 6> sliceStrides(srcSize, one);
+
+  Type elemType = srcType.getElementType();
+  Value slice = rewriter.create<tensor::ExtractSliceOp>(
+      loc, RankedTensorType::get(sliceShape, elemType), source, sliceOffsets,
+      sliceSizes, sliceStrides);
+
+  // Drop the unit dimensions to obtain a plain [h, w] tensor.
+  return tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, slice, RankedTensorType::get({h, w}, elemType));
+}
+
+/// Inserts the 2-D tensor `source` into the rank-`destSize` tensor `dest`.
+///
+/// First, `source` is expanded to a full-rank tensor that is unit-sized
+/// everywhere except `heightIdx` (= `height`) and `widthIdx` (= `width`). This
+/// is done with a rank-reducing insert into a fresh tensor.empty. That tensor
+/// is then inserted into `dest`:
+/// - offsets are `outLoopIndex` along `outLoopIdx`, `inLoopIndex` along
+///   `inLoopIdx`, and 0 elsewhere;
+/// - strides are 1.
+///
+/// Returns the updated `dest`.
+Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
+                   Value dest, Value outLoopIndex, Value inLoopIndex,
+                   int64_t height, int64_t width, int64_t outLoopIdx,
+                   int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                   int64_t destSize) {
+  Type elemType = cast<ShapedType>(source.getType()).getElementType();
+
+  // Lift the 2-D source into a full-rank tensor with unit outer dims.
+  SmallVector<int64_t> fullShape(destSize, 1);
+  fullShape[heightIdx] = height;
+  fullShape[widthIdx] = width;
+  Value empty = rewriter.create<tensor::EmptyOp>(loc, fullShape, elemType);
+  Value expanded = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, source, empty);
+
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+  OpFoldResult one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> offsets(destSize, zero);
+  SmallVector<OpFoldResult, 6> sizes(destSize, one);
+  SmallVector<OpFoldResult, 6> strides(destSize, one);
+  offsets[outLoopIdx] = outLoopIndex;
+  offsets[inLoopIdx] = inLoopIndex;
+  sizes[heightIdx] = rewriter.getIndexAttr(height);
+  sizes[widthIdx] = rewriter.getIndexAttr(width);
+
+  return rewriter.create<tensor::InsertSliceOp>(loc, expanded, dest, offsets,
+                                                sizes, strides);
+}
+
 Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
   auto type = cast<ShapedType>(data.getType());
   auto elementType = type.getElementType();
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
                                                   reassociation);
 }
 
+// This function transforms the filter. The filter's data layout is FHWC.
+// The transformation matrices are 2-dimensional, so we first extract H x W
+// from FHWC. We generate 2 levels of loops to iterate on F and C.
+// After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract filter<h x w> from filter<f x h x w x c>
+//     %ret = linalg.matmul G, %extracted
+//     %ret = linalg.matmul %ret, GT
+//     %inserted = insert %ret into filter<tile_h x tile_w x h x w x c x f>
+//
+// `leftTransform` / `rightTransform` select whether G (along H) and GT (along
+// W) are applied. For a skipped side, the inserted slice has size 1 in that
+// dimension.
+//
+// Returns the result of the loop nest (threaded from `retValue`), or a null
+// Value when the filter shape or the (m, r) configuration is unsupported.
+// All such checks run before any IR is created, so a null return leaves the
+// IR untouched.
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  if (filterH != r && filterH != 1)
+    return Value();
+  if (filterW != r && filterW != 1)
+    return Value();
+
+  // Resolve the constant transform matrices before emitting any IR. Bailing
+  // out after the loops are built would leave a half-constructed loop nest
+  // in the IR and the insertion point inside it.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *GMatrix = nullptr;
+  if (leftTransform) {
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    GMatrix = &it->second;
+  }
+  const TransformMatrix *GTMatrix = nullptr;
+  if (rightTransform) {
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    GTMatrix = &it->second;
+  }
+
+  // Loop over F (outer) and C (inner), threading the destination tensor
+  // through both loops as the iteration argument.
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value FIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (F, H, W, C)
+  auto extractFilter = extract2DData(
+      rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+  int64_t retRows = 1;
+  Value matmulRetValue = extractFilter;
+  if (leftTransform) {
+    retRows = GMatrix->rows;
+    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value G = create2DTransformMatrix(rewriter, loc, *GMatrix, elementType);
+    // Multiply G x g
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    auto matmulType =
+        RankedTensorType::get({retRows, GTMatrix->cols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value GT = create2DTransformMatrix(rewriter, loc, *GTMatrix, elementType);
+    // Multiply u = (G x g) x GT
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert the transformed tile into dims (2, 3) of the 6-D destination
+  // (tile_h, tile_w, alpha_h, alpha_w, C, F), at C = CIter and F = FIter.
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  int64_t retHeight = leftTransform ? m + r - 1 : 1;
+  int64_t retWidth = rightTransform ? m + r - 1 : 1;
+  auto insertSliceOp = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth,
+      /*outLoopIdx=*/5, /*inLoopIdx=*/4, /*heightIdx=*/2, /*widthIdx=*/3,
+      /*destSize=*/6);
+
+  rewriter.create<scf::YieldOp>(loc, insertSliceOp);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// This function transforms the input. The input's data layout is NHWC.
+// The transformation matrices are 2-dimensional, so we first extract H x W
+// from NHWC. We generate 2 levels of loops to iterate on N and C.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract input<h x w> from input<n x h x w x c>
+//     %ret = linalg.matmul BT, %extracted
+//     %ret = linalg.matmul %ret, B
+//     %inserted = insert %ret into input<tile_h x tile_w x h x w x n x c>
+//
+// `leftTransform` / `rightTransform` select whether BT (along H) and B (along
+// W) are applied. For a skipped side, the expected input extent and the
+// inserted slice size are 1.
+//
+// Returns the result of the loop nest (threaded from `retValue`), or a null
+// Value when the input shape or the (m, r) configuration is unsupported.
+// All such checks run before any IR is created, so a null return leaves the
+// IR untouched.
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (inputH != alphaH && inputH != 1)
+    return Value();
+  if (inputW != alphaW && inputW != 1)
+    return Value();
+
+  // Resolve the constant transform matrices before emitting any IR. Bailing
+  // out after the loops are built would leave a half-constructed loop nest
+  // in the IR and the insertion point inside it.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *BTMatrix = nullptr;
+  if (leftTransform) {
+    auto it = BTMatrices.find(key);
+    if (it == BTMatrices.end())
+      return Value();
+    BTMatrix = &it->second;
+  }
+  const TransformMatrix *BMatrix = nullptr;
+  if (rightTransform) {
+    auto it = BMatrices.find(key);
+    if (it == BMatrices.end())
+      return Value();
+    BMatrix = &it->second;
+  }
+
+  // Loop over N (outer) and C (inner), threading the destination tensor
+  // through both loops as the iteration argument.
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (N, H, W, C)
+  auto extractInput = extract2DData(
+      rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  Value matmulRetValue = extractInput;
+  if (leftTransform) {
+    retRows = BTMatrix->rows;
+    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    // NOTE(review): BT/B are materialized as f32 regardless of elementType,
+    // whereas filterTransform builds G/GT with the element type. Confirm
+    // which is intended.
+    Value BT = create2DTransformMatrix(rewriter, loc, *BTMatrix,
+                                       rewriter.getF32Type());
+    // Multiply BT x d
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    retCols = BMatrix->cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+    Value B = create2DTransformMatrix(rewriter, loc, *BMatrix,
+                                      rewriter.getF32Type());
+    // Multiply v = (BT x d) x B
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert v into dims (2, 3) of the 6-D destination
+  // (tile_h, tile_w, alpha_h, alpha_w, N, C), at N = NIter and C = CIter.
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  auto combinedVal = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols,
+      /*outLoopIdx=*/4, /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3,
+      /*destSize=*/6);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
 // This function generates linalg.batch_matmul to multiply input with filter.
 // linalg.batch_matmul only supports 3-dimension data sets. We can treat
 // tileH x tileW x H x W data as the 1-dimension data array. That is to convert
@@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
+// This function transforms the result of the batched matmul back into the
+// convolution output. `value` has layout (TileH, TileW, H, W, N, F), where
+// H x W is the Winograd tile; `output` has layout (N, H, W, F). The
+// transformation matrices are 2-dimensional, so we extract H x W from `value`
+// first, using 2 levels of loops to iterate on N and F. After the
+// transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %f = lo_f to hi_f step 1
+//     %extracted = extract input<h x w> from value<1 x 1 x h x w x n x f>
+//     %ret = linalg.matmul AT, %extracted
+//     %ret = linalg.matmul %ret, A
+//     %ret = linalg.generic scalarFactor * %ret
+//     %inserted = insert %ret into output<n x h x w x f>
+//
+// When leftTransform (rightTransform) is false the AT (A) multiplication is
+// skipped, and the H (W) dimension of `value` is expected to be 1.
+//
+// Returns a null Value, without creating any IR, when the H/W sizes of `value`
+// do not fit F(m, r) or when (m, r) has no registered transform matrix.
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+                      Value output, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to AT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      ATMatrices = {
+          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+      };
+
+  // Map from (m, r) to A transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      AMatrices = {
+          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+      };
+
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F
+  int64_t valueH = valueShape[2];
+  int64_t valueW = valueShape[3];
+  int64_t valueN = valueShape[4];
+  int64_t valueF = valueShape[5];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (valueH != alphaH && valueH != 1)
+    return Value();
+  if (valueW != alphaW && valueW != 1)
+    return Value();
+
+  // Look up the transform matrices before creating any IR. Bailing out after
+  // the loops below have been built would leave unterminated scf.for ops in
+  // the IR and make the calling rewrite pattern fail after it has already
+  // modified the IR.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *ATMatrix = nullptr;
+  if (leftTransform) {
+    auto it = ATMatrices.find(key);
+    if (it == ATMatrices.end())
+      return Value();
+    ATMatrix = &it->second;
+  }
+  const TransformMatrix *AMatrix = nullptr;
+  if (rightTransform) {
+    auto it = AMatrices.find(key);
+    if (it == AMatrices.end())
+      return Value();
+    AMatrix = &it->second;
+  }
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value FIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (1, 1, H, W, N, F)
+  auto extractValue = extract2DData(
+      rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4,
+      /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6);
+
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  int64_t leftScalarFactor = 1;
+  int64_t rightScalarFactor = 1;
+  Value matmulRetValue = extractValue;
+  if (ATMatrix) {
+    leftScalarFactor = ATMatrix->scalarFactor;
+    retRows = ATMatrix->rows;
+    auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value AT = create2DTransformMatrix(rewriter, loc, *ATMatrix, elementType);
+    // Multiply AT x (extracted H x W tile)
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (AMatrix) {
+    rightScalarFactor = AMatrix->scalarFactor;
+    retCols = AMatrix->cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value A = create2DTransformMatrix(rewriter, loc, *AMatrix, elementType);
+    // Multiply y = (AT x tile) x A
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Multiply by the combined scalar factor of AT and A, broadcasting the
+  // scalar over the 2-D result via a rank-0 indexing map.
+  Value scalarFactor = rewriter.create<arith::ConstantOp>(
+      loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
+
+  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+  SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
+                                       identityAffineMap, identityAffineMap};
+  auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
+      loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
+      ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value scalarVal = args[0];
+        Value matrixVal = args[1];
+        Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
+                                                           matrixVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
+
+  // Insert slice y
+  // Insert (H, W) to (N, H, W, F)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  Value combinedVal = insert2DData(rewriter, loc, scalarMatrixOp.getResult(0),
+                                   iterArg, NIter, FIter, retRows, retCols,
+                                   /*outLoopIdx=*/0,
+                                   /*inLoopIdx=*/3, /*heightIdx=*/1,
+                                   /*widthIdx=*/2, /*destSize=*/4);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
 Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
                             RankedTensorType alignedType) {
   Value alignedInput = rewriter.create<tensor::EmptyOp>(
@@ -289,6 +938,123 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   return transformedOutput.getDefiningOp();
 }
 
+/// Decomposes a linalg.winograd_filter_transform op with filterTransform and
+/// replaces the op with the result. Returns the op defining the replacement
+/// value, or failure if filterTransform could not handle the op.
+FailureOr<Operation *>
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradFilterTransformOp op) {
+  // Dimensions 1 and 2 of the filter are H and W (presumably FHWC layout).
+  Value filter = op.getFilter();
+  auto shape = cast<ShapedType>(filter.getType()).getShape();
+
+  // A unit dimension means that side needs no transform: F(m x 1, r x 1)
+  // only needs the left transform and F(1 x m, 1 x r) only the right one.
+  const bool needLeft = shape[1] != 1;
+  const bool needRight = shape[2] != 1;
+
+  Value result =
+      filterTransform(rewriter, op.getLoc(), filter, op.getOutput(),
+                      op.getM(), op.getR(), needLeft, needRight);
+  if (!result)
+    return failure();
+
+  rewriter.replaceOp(op, result);
+  return result.getDefiningOp();
+}
+
+/// Decomposes a linalg.winograd_input_transform op with inputTransform and
+/// replaces the op with the result. Returns the op defining the replacement
+/// value, or failure if inputTransform could not handle the op.
+FailureOr<Operation *>
+decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
+                                      linalg::WinogradInputTransformOp op) {
+  // Dimensions 1 and 2 of the input are H and W (presumably NHWC layout).
+  auto shape = cast<ShapedType>(op.getInput().getType()).getShape();
+
+  // A unit dimension means that side needs no transform: F(m x 1, r x 1)
+  // only needs the left transform and F(1 x m, 1 x r) only the right one.
+  const bool needLeft = shape[1] != 1;
+  const bool needRight = shape[2] != 1;
+
+  Value result =
+      inputTransform(rewriter, op.getLoc(), op.getInput(), op.getOutput(),
+                     op.getM(), op.getR(), needLeft, needRight);
+  if (!result)
+    return failure();
+
+  rewriter.replaceOp(op, result);
+  return result.getDefiningOp();
+}
+
+/// Decomposes a linalg.winograd_output_transform op with outputTransform and
+/// replaces the op with the result. Returns the op defining the replacement
+/// value, or failure if outputTransform could not handle the op.
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  // `value` is laid out as (TileH, TileW, H, W, N, F); dimensions 2 and 3 are
+  // the Winograd tile sizes.
+  Value value = op.getValue();
+  auto shape = cast<ShapedType>(value.getType()).getShape();
+
+  // A unit dimension means that side needs no transform: F(m x 1, r x 1)
+  // only needs the left transform and F(1 x m, 1 x r) only the right one.
+  const bool needLeft = shape[2] != 1;
+  const bool needRight = shape[3] != 1;
+
+  Value result =
+      outputTransform(rewriter, op.getLoc(), value, op.getOutput(),
+                      op.getM(), op.getR(), needLeft, needRight);
+  if (!result)
+    return failure();
+
+  rewriter.replaceOp(op, result);
+  return result.getDefiningOp();
+}
+
+/// Rewrite pattern that decomposes linalg.winograd_filter_transform by
+/// delegating to decomposeWinogradFilterTransformHelper.
+class DecomposeWinogradFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> decomposed =
+        decomposeWinogradFilterTransformHelper(rewriter, op);
+    return failed(decomposed) ? failure() : success();
+  }
+};
+
+/// Rewrite pattern that decomposes linalg.winograd_input_transform by
+/// delegating to decomposeWinogradInputTransformHelper.
+class DecomposeWinogradInputTransform final
+    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> decomposed =
+        decomposeWinogradInputTransformHelper(rewriter, op);
+    return failed(decomposed) ? failure() : success();
+  }
+};
+
+/// Rewrite pattern that decomposes linalg.winograd_output_transform by
+/// delegating to decomposeWinogradOutputTransformHelper.
+class DecomposeWinogradOutputTransform final
+    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> decomposed =
+        decomposeWinogradOutputTransformHelper(rewriter, op);
+    return failed(decomposed) ? failure() : success();
+  }
+};
+
 class WinogradConv2DNhwcFhwc final
     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
 public:
@@ -323,5 +1089,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
   patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
 }
 
+/// Adds the patterns that decompose the Winograd filter, input and output
+/// transform ops to `patterns`.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<DecomposeWinogradFilterTransform,
+                  DecomposeWinogradInputTransform,
+                  DecomposeWinogradOutputTransform>(patterns.getContext());
+}
+
 } // end namespace linalg
 } // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
new file mode 100644
index 0000000000000..917d089c1981c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s
+
+// Input IR: a Winograd F(4x4, 3x3) convolution already expressed with the
+// filter/input/output transform ops, surrounded by the batched matmul.
+// #map broadcasts the single bias element; #map1 is the identity map.
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  // Initialize the 2x4x4x2 output with the bias value from %arg2.
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  // Filter transform (G x g x G^T) on the 2x3x3x5 filter.
+  %2 = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
+  // Input transform (B^T x d x B) on the 2x6x6x5 input.
+  %4 = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+  %5 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%4 : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
+  // Element-wise product over the 6x6 tile, done as 36 batched matmuls.
+  %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+  %6 = tensor.empty() : tensor<36x2x2xf32>
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+  // Output transform (A^T x y x A) writes into the bias-initialized output.
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<1x1x6x6x2x2xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %8 : tensor<2x4x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:   %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG5]], %[[ARG3]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S4]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:        %[[S17:.*]] = arith.mulf %[[IN]], %[[IN_9]] : f32
+// CHECK-NEXT:        linalg.yield %[[S17]] : f32
+// CHECK-NEXT:      } -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S16:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 12cb46a5968f1..5899f56da7345 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -127,6 +127,9 @@ struct TestLinalgTransforms
       *this, "test-winograd-conv2d",
       llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
       llvm::cl::init(false)};
+  // When set (and no option checked earlier in runOnOperation fires first),
+  // the pass runs applyDecomposeWinogradOps on the function.
+  Option<bool> testDecomposeWinogradOps{
+      *this, "test-decompose-winograd-ops",
+      llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
 };
 } // namespace
 
@@ -218,6 +221,12 @@ static void applyWinogradConv2D(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+/// Greedily applies the Winograd-op decomposition patterns to `funcOp`.
+static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
+  RewritePatternSet decompositionPatterns(ctx);
+  populateDecomposeWinogradOpsPatterns(decompositionPatterns);
+  // The LogicalResult of the greedy driver is discarded in this test pass.
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(decompositionPatterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -244,6 +253,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnnecessaryInputs(getOperation());
   if (testWinogradConv2D)
     return applyWinogradConv2D(getOperation());
+  if (testDecomposeWinogradOps)
+    return applyDecomposeWinogradOps(getOperation());
 }
 
 namespace mlir {

>From c94b1a3d2b30eefaa556b8ddf1f4767d89d72fe0 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 09:43:43 +0100
Subject: [PATCH 04/22] Revert "[mlir][linalg] Add transform operator for
 Winograd Conv2D algorithm"

This reverts commit 374b0d5b83ce080bea690199380e270a36ad1c52.
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 51 -----------
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 --
 .../TransformOps/LinalgTransformOps.cpp       | 25 ------
 .../Linalg/Transforms/WinogradConv2D.cpp      |  6 --
 .../Linalg/transform-winograd-conv2d.mlir     | 88 -------------------
 5 files changed, 177 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 68d0f713caad4..93e2c2db729da 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,55 +2587,4 @@ def MapCopyToThreadsOp :
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// Winograd Conv2D
-//===----------------------------------------------------------------------===//
-
-def WinogradConv2DOp : Op<Transform_Dialect,
-    "structured.winograd_conv2d",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
-  let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
-    matrix multiply. Before the matrix multiply, it will convert filter and
-    input into a format suitable for batched matrix multiply. After the matrix
-    multiply, it will convert output to the final result tensor.
-
-    The algorithm F(m x m, r x r) is
-
-    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
-    The size of output Y is m x m. The size of filter g is r x r. The size of
-    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-    transformation matrices.
-
-    #### Return modes:
-
-    This operation fails if `target` is unsupported. Otherwise, the operation
-    succeeds and returns a handle of the sequence that replaces the original
-    convolution.
-  }];
-
-  let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$m,
-                       I64Attr:$r);
-  let results = (outs TransformHandleTypeInterface:$transformed);
-
-  let assemblyFormat =
-    "$target attr-dict `:` functional-type($target, results)";
-
-  let builders = [
-    OpBuilder<(ins "Value":$target)>
-  ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index da107b66257a5..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,13 +1312,6 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
-/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
-/// F(m x m, r x r). m is the dimension size of output and r is the dimension
-/// size of filter.
-FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r);
-
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d051b29e1f06f..bc02788f9c441 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,31 +3480,6 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
-//===----------------------------------------------------------------------===//
-// WinogradConv2DOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
-    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  rewriter.setInsertionPoint(target);
-  auto maybeTransformed =
-      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
-            return winogradConv2D(rewriter, op, getM(), getR());
-          })
-          .Default([&](Operation *op) {
-            return rewriter.notifyMatchFailure(op, "not supported");
-          });
-
-  if (failed(maybeTransformed))
-    return emitDefaultSilenceableFailure(target);
-
-  results.push_back(*maybeTransformed);
-  return DiagnosedSilenceableFailure::success();
-}
-
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d1f4be8bbf29a..86e834d51f2fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -311,12 +311,6 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
-FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r) {
-  return winogradConv2DHelper(rewriter, op, m, r);
-}
-
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
deleted file mode 100644
index 1e74fea5a1c31..0000000000000
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ /dev/null
@@ -1,88 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
-
-func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-  %0 = tensor.empty() : tensor<2x8x8x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x8x8x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-  return %2 : tensor<2x8x8x2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
-    transform.yield
-  }
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-  %0 = tensor.empty() : tensor<2x9x9x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x9x9x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
-  return %2 : tensor<2x9x9x2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
-    transform.yield
-  }
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:   %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT:   %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT:   %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }

>From 5a391881394094bfd747cb97bf023ed3df06923e Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 09:44:19 +0100
Subject: [PATCH 05/22] Revert "[mlir][linalg] Implement Conv2D using Winograd
 Conv2D algorithm"

This reverts commit 4240341b4f06f1b77f63b0f619cae3804d88eb68.
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       | 114 -------
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 -
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  78 -----
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 -
 .../Linalg/Transforms/WinogradConv2D.cpp      | 321 ------------------
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 248 --------------
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  13 -
 7 files changed, 779 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 delete mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index de1097b6ac27b..64c538367267d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,118 +154,4 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
-  let summary = "Winograd filter transform operator";
-  let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
-    matrix multiply. Before the matrix multiply, it will convert filter and
-    input into a format suitable for batched matrix multiply. After the matrix
-    multiply, it will convert output to the final result tensor.
-
-    The algorithm F(m x m, r x r) is
-
-    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
-    The size of output Y is m x m. The size of filter g is r x r. The size of
-    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-    transformation matrices.
-
-    This operator is defined to represent the high level concept of filter
-    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
-  }];
-
-  let arguments = (ins AnyRankedTensor:$filter,
-                       AnyRankedTensor:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
-  );
-
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
-    `ins` `(` $filter `:` type($filter) `)`
-    `outs` `(` $output `:` type($output) `)`
-    `->` type($result)
-  }];
-  let hasVerifier = 1;
-}
-
-def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
-  let summary = "Winograd input transform operator";
-  let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
-    matrix multiply. Before the matrix multiply, it will convert filter and
-    input into a format suitable for batched matrix multiply. After the matrix
-    multiply, it will convert output to the final result tensor.
-
-    The algorithm F(m x m, r x r) is
-
-    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
-    The size of output Y is m x m. The size of filter g is r x r. The size of
-    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-    transformation matrices.
-
-    This operator is defined to represent the high level concept of input
-    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
-  }];
-
-  let arguments = (ins AnyRankedTensor:$input,
-                       AnyRankedTensor:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
-  );
-
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
-    `ins` `(` $input `:` type($input) `)`
-    `outs` `(` $output `:` type($output) `)`
-    `->` type($result)
-  }];
-  let hasVerifier = 1;
-}
-
-def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
-  let summary = "Winograd output transform operator";
-  let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
-    matrix multiply. Before the matrix multiply, it will convert filter and
-    input into a format suitable for batched matrix multiply. After the matrix
-    multiply, it will convert output to the final result tensor.
-
-    The algorithm F(m x m, r x r) is
-
-    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
-    The size of output Y is m x m. The size of filter g is r x r. The size of
-    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-    transformation matrices.
-
-    This operator is defined to represent the high level concept of output
-    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
-  }];
-
-  let arguments = (ins AnyRankedTensor:$value,
-                       AnyRankedTensor:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
-  );
-
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
-    `ins` `(` $value `:` type($value) `)`
-    `outs` `(` $output `:` type($output) `)`
-    `->` type($result)
-  }];
-  let hasVerifier = 1;
-}
-
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..05e97befdec1f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,10 +1692,6 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
-/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r);
-
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7bf2a5bca037f..57d126603ebd7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,84 +2734,6 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
-//===----------------------------------------------------------------------===//
-// WinogradFilterTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradFilterTransformOp::verify() {
-  auto filterType = cast<ShapedType>(getFilter().getType());
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  auto filterElemType = filterType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (filterElemType != outputElemType) {
-    return emitOpError() << "expected element type of input " << filterElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
-
-  unsigned filterRank = filterType.getRank();
-  if (filterRank != 4)
-    return emitOpError() << "expected rank of input is 4";
-
-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 6)
-    return emitOpError() << "expected rank of output is 6";
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// WinogradInputTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradInputTransformOp::verify() {
-  auto inputType = cast<ShapedType>(getInput().getType());
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  auto inputElemType = inputType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (inputElemType != outputElemType) {
-    return emitOpError() << "expected element type of input " << inputElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
-
-  unsigned inputRank = inputType.getRank();
-  if (inputRank != 4)
-    return emitOpError() << "expected rank of input is 4";
-
-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 6)
-    return emitOpError() << "expected rank of output is 6";
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// WinogradOutputTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradOutputTransformOp::verify() {
-  auto valueType = cast<ShapedType>(getValue().getType());
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  auto valueElemType = valueType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (valueElemType != outputElemType) {
-    return emitOpError() << "expected element type of value " << valueElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
-
-  unsigned valueRank = valueType.getRank();
-  if (valueRank != 6)
-    return emitOpError() << "expected rank of input is 6";
-
-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 4)
-    return emitOpError() << "expected rank of output is 4";
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a7dcc29b5b9be..7e3dc56e0acdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,7 +38,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
-  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
deleted file mode 100644
index 86e834d51f2fc..0000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ /dev/null
@@ -1,321 +0,0 @@
-//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Implement Winograd Conv2D algorithm. The implementation is based on the
-// paper: Fast Algorithms for Convolutional Neural Networks
-// (https://arxiv.org/abs/1509.09308)
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/MathExtras.h"
-
-namespace mlir {
-namespace linalg {
-
-namespace {
-
-using TransformMapKeyTy = std::pair<int, int>;
-
-// We use F(m, r) to define the size of minimal filtering algorithms.
-// m is the output dimension and r is the filter dimension. We can get
-// the input dimension, alpha, from the formula, alpha = m + r - 1.
-//
-// For example, when m = 2 and r = 3, we know its input size is 4.
-// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
-Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
-  auto type = cast<ShapedType>(data.getType());
-  auto elementType = type.getElementType();
-  auto shape = type.getShape();
-  auto collapseType = RankedTensorType::get(
-      {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
-      elementType);
-  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
-  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
-                                                  reassociation);
-}
-
-// This function generates linalg.batch_matmul to multiply input with filter.
-// linalg.batch_matmul only supports 3-dimension data sets. We can treat
-// tileH x tileW x H x W data as the 1-dimension data array. That is to convert
-// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we
-// can convert 6-dimension input data to 3-dimension representation that is
-// suitable for linalg.batch_matmul.
-//
-// Batched matmul will do the matrix multiply with the reduction on channel.
-//
-// We get
-//
-// %collapsed_input = tensor.collapse_shape %input
-// %collapsed_filter = tensor.collapse_shape %filter
-// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
-// %expanded_ret = tensor.expand_shape %ret
-//
-// After this function, we get return value with data layout
-// (tileH, tileW, H, W, N, F).
-Value matrixMultiply(RewriterBase &rewriter, Location loc,
-                     Value transformedFilter, Value transformedInput) {
-  auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
-  auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
-
-  // Batched matrix multiply
-  auto filterType = cast<ShapedType>(transformedFilter.getType());
-  auto filterShape = filterType.getShape();
-  auto inputType = cast<ShapedType>(transformedInput.getType());
-  auto inputElemType = inputType.getElementType();
-  auto inputShape = inputType.getShape();
-
-  auto matmulType = RankedTensorType::get(
-      {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
-       inputShape[4], filterShape[5]},
-      inputElemType);
-  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                inputElemType);
-
-  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
-      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
-      ValueRange{init});
-
-  // Expand matmul result
-  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
-  auto expandType =
-      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
-                             inputShape[3], inputShape[4], filterShape[5]},
-                            inputElemType);
-  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
-      loc, expandType, matmulOp.getResult(0), reassociation);
-  return expandOutput;
-}
-
-Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
-                            RankedTensorType alignedType) {
-  Value alignedInput = rewriter.create<tensor::EmptyOp>(
-      loc, alignedType.getShape(), alignedType.getElementType());
-
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
-  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
-
-  auto valueType = cast<ShapedType>(value.getType());
-  auto valueShape = valueType.getShape();
-  SmallVector<OpFoldResult, 4> sizes;
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[0]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[1]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[2]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[3]));
-
-  return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
-                                                offsets, sizes, strides);
-}
-
-Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
-                               Value value, RankedTensorType extractedType) {
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
-  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
-
-  auto extractedShape = extractedType.getShape();
-  SmallVector<OpFoldResult, 4> sizes;
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3]));
-
-  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
-                                                 offsets, sizes, strides);
-}
-
-bool hasAllOneValues(DenseIntElementsAttr attr) {
-  return llvm::all_of(
-      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
-}
-
-FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
-                                            linalg::Conv2DNhwcFhwcOp convOp,
-                                            int64_t m, int64_t r) {
-  Value input = convOp.getInputs()[0];
-  Value filter = convOp.getInputs()[1];
-  Value output = convOp.getOutputs()[0];
-  auto inputType = cast<ShapedType>(input.getType());
-  auto filterType = cast<ShapedType>(filter.getType());
-  auto outputType = cast<ShapedType>(output.getType());
-
-  if (!inputType.hasStaticShape())
-    return rewriter.notifyMatchFailure(convOp,
-                                       "expected a static shape for the input");
-
-  if (!filterType.hasStaticShape())
-    return rewriter.notifyMatchFailure(
-        convOp, "expected a static shape for the filter");
-
-  if (!hasAllOneValues(convOp.getDilations()))
-    return rewriter.notifyMatchFailure(convOp,
-                                       "expected all ones for dilations");
-
-  if (!hasAllOneValues(convOp.getStrides()))
-    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
-
-  auto filterShape = filterType.getShape();
-  int64_t filterF = filterShape[0];
-  int64_t filterH = filterShape[1];
-  int64_t filterW = filterShape[2];
-  int64_t filterC = filterShape[3];
-  auto inputShape = inputType.getShape();
-  int64_t inputN = inputShape[0];
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
-  int64_t inputC = inputShape[3];
-  auto outputShape = outputType.getShape();
-  int64_t outputN = outputShape[0];
-  int64_t outputH = outputShape[1];
-  int64_t outputW = outputShape[2];
-  int64_t outputF = outputShape[3];
-
-  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
-  bool isSupportedFilter = false;
-  if (filterH == filterW && filterH == r)
-    isSupportedFilter = true;
-  if (filterH == r && filterW == 1)
-    isSupportedFilter = true;
-  if (filterH == 1 && filterW == r)
-    isSupportedFilter = true;
-
-  if (!isSupportedFilter)
-    return rewriter.notifyMatchFailure(
-        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
-
-  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5)
-  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
-      F_2_3, F_4_3, F_2_5};
-
-  TransformMapKeyTy key = {m, r};
-  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
-  // If we cannot find the constant transformation matrix, it means we do
-  // not support this configuration yet.
-  if (it == validConfigs.end())
-    return failure();
-
-  // All the criterias are satisfied. We can do Winograd Conv2D.
-  Location loc = convOp.getLoc();
-
-  // For F(m x 1, r x 1), we only need to do left side transform.
-  bool leftTransform = filterH != 1;
-  // For F(1 x m, 1 x r), we only need to do right side transform.
-  bool rightTransform = filterW != 1;
-  int64_t heightM = leftTransform ? m : 1;
-  int64_t widthM = rightTransform ? m : 1;
-  int64_t heightR = leftTransform ? r : 1;
-  int64_t widthR = rightTransform ? r : 1;
-
-  // --- Create operator for filter transform ---
-  Type elementType = filterType.getElementType();
-  int64_t alphaH = heightM + heightR - 1;
-  int64_t alphaW = widthM + widthR - 1;
-  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
-  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
-  auto retType = RankedTensorType::get(
-      {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
-  Value retValue =
-      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
-  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
-      loc, retType, filter, retValue, m, r);
-
-  // --- Create operator for input transform ---
-
-  // When input size - (r - 1) is not aligned with output tile size, we need to
-  // pad the input data to create the full tiles as tiling.
-  int64_t alignedInputH = tileH * heightM + (heightR - 1);
-  int64_t alignedInputW = tileW * widthM + (widthR - 1);
-  if (alignedInputH != inputH || alignedInputW != inputW) {
-    auto alignedInputType = RankedTensorType::get(
-        {inputN, alignedInputH, alignedInputW, inputC}, elementType);
-    input = insertToAlignedTensor(rewriter, loc, input, alignedInputType);
-  }
-
-  retType = RankedTensorType::get(
-      {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType);
-  retValue =
-      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
-  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
-      loc, retType, input, retValue, m, r);
-
-  Value matmulRet =
-      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
-
-  // --- Create operator for output transform ---
-
-  // When output size is not aligned with output tile size, we need to pad the
-  // output buffer to insert the full tiles after tiling.
-  int64_t alignedOutputH = tileH * heightM;
-  int64_t alignedOutputW = tileW * widthM;
-  bool isOutputUnaligned =
-      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
-  if (isOutputUnaligned) {
-    auto alignedOutputType = RankedTensorType::get(
-        {outputN, alignedOutputH, alignedOutputW, outputF}, elementType);
-    output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType);
-    outputType = alignedOutputType;
-  }
-
-  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
-      loc, outputType, matmulRet, output, m, r);
-
-  // When output size is not aligned with output tile size, extract the
-  // value from the padded buffer.
-  if (isOutputUnaligned) {
-    transformedOutput = extractFromAlignedTensor(
-        rewriter, loc, transformedOutput,
-        RankedTensorType::get({outputN, outputH, outputW, outputF},
-                              elementType));
-  }
-
-  rewriter.replaceOp(convOp, transformedOutput);
-
-  return transformedOutput.getDefiningOp();
-}
-
-class WinogradConv2DNhwcFhwc final
-    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
-      : OpRewritePattern(context), m(m), r(r) {}
-
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
-      return failure();
-
-    return success();
-  }
-
-private:
-  int64_t m;
-  int64_t r;
-};
-} // end anonymous namespace
-
-//===----------------------------------------------------------------------===//
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r) {
-  MLIRContext *context = patterns.getContext();
-  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
-}
-
-} // end namespace linalg
-} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
deleted file mode 100644
index 6cca3c602d4c0..0000000000000
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ /dev/null
@@ -1,248 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
-
-func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_4x4_3x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
-  %0 = tensor.empty() : tensor<2x2x2x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x2x2x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-  return %2 : tensor<2x2x2x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_2x2_5x5
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x1x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x1x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-  return %2 : tensor<2x1x4x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_1x4_1x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x1x4x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x1x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x1x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-  return %2 : tensor<2x4x1x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_4x1_3x1
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x4x1x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-  %0 = tensor.empty() : tensor<2x8x8x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x8x8x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-  return %2 : tensor<2x8x8x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_aligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-  %0 = tensor.empty() : tensor<2x9x9x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x9x9x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
-  return %2 : tensor<2x9x9x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:  %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT:  %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT:  %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_1
-// CHECK: linalg.conv_2d_nhwc_fhwc
-
-// -----
-
-func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_2
-// CHECK: linalg.conv_2d_nhwc_fhwc
-
-// -----
-
-func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
-  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-  return %0 : tensor<?x?x?x?xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_3
-// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 12cb46a5968f1..4892fa2f99a7c 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,10 +123,6 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
-  Option<bool> testWinogradConv2D{
-      *this, "test-winograd-conv2d",
-      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
-      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -211,13 +207,6 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyWinogradConv2D(func::FuncOp funcOp) {
-  RewritePatternSet patterns(funcOp.getContext());
-  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
-  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -242,8 +231,6 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
-  if (testWinogradConv2D)
-    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {

>From 690662771c806a2f7301bdc4dedc983047c41d35 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:24:07 +0100
Subject: [PATCH 06/22] [mlir][linalg] Implement Conv2D using Winograd Conv2D
 algorithm

Define high-level Winograd operators and convert conv_2d_nhwc_fhwc into
them. The Winograd Conv2D algorithm needs three transform operators: one
each for the input, filter, and output transformations.

The formula of Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper "Fast Algorithms for
Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       | 117 ++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 107 ++++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 334 ++++++++++++++++++
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 193 ++++++++++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  13 +
 7 files changed, 769 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..a9007c8db3078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,121 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp :
+    Linalg_Op<"winograd_filter_transform",
+              [AllElementTypesMatch<["filter", "output"]>]> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply, transforming the filter and input beforehand
+    and transforming the multiply's result into the final output afterwards.
+
+    The algorithm F(m x m, r x r) computes
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    where @ is an element-wise product, Y is the m x m output, g is the r x r
+    filter, d is the (m + r - 1) x (m + r - 1) input, and A^T, A, G^T, G, B^T,
+    and B are transformation matrices.
+
+    This operator represents the filter transformation (G x g x G^T) of the
+    Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp :
+    Linalg_Op<"winograd_input_transform",
+              [AllElementTypesMatch<["input", "output"]>]> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply, transforming the filter and input beforehand
+    and transforming the multiply's result into the final output afterwards.
+
+    The algorithm F(m x m, r x r) computes
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    where @ is an element-wise product, Y is the m x m output, g is the r x r
+    filter, d is the (m + r - 1) x (m + r - 1) input, and A^T, A, G^T, G, B^T,
+    and B are transformation matrices.
+
+    This operator represents the input transformation (B^T x d x B) of the
+    Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
+                       TensorRankOf<[AnyType], [6]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp :
+    Linalg_Op<"winograd_output_transform",
+              [AllElementTypesMatch<["value", "output"]>]> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply, transforming the filter and input beforehand
+    and transforming the multiply's result into the final output afterwards.
+
+    The algorithm F(m x m, r x r) computes
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    where @ is an element-wise product, Y is the m x m output, g is the r x r
+    filter, d is the (m + r - 1) x (m + r - 1) input, and A^T, A, G^T, G, B^T,
+    and B are transformation matrices.
+
+    This operator represents the output transformation (A^T x y x A), where y
+    is the result of the batched matrix multiply.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Populate `patterns` with rewrites applying Winograd Conv2D F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..1283315f2eaef 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,113 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Filter dims 1/2 (H/W) must each be r or 1, and not both 1.
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t r = getR();
+  if (filterH != r && filterH != 1)
+    return emitOpError("expect filter height to be either r or 1");
+  if (filterW != r && filterW != 1)
+    return emitOpError("expect filter width to be either r or 1");
+  if (filterH == 1 && filterW == 1)
+    return emitOpError("expect filter height or width to be r");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Checks input H/W (dims 1, 2) against output dims 0-3; assumes static dims.
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[0];
+  int64_t outputW = outputShape[1];
+  int64_t outputTileH = outputShape[2];
+  int64_t outputTileW = outputShape[3];
+  int64_t m = getM();
+  int64_t r = getR();
+  if (m <= 0 || r <= 0) // m is a divisor below.
+    return emitOpError("expect m and r to be positive");
+  bool leftTransform = inputH != 1;
+  bool rightTransform = inputW != 1;
+  if (!leftTransform && !rightTransform)
+    return emitOpError("expect input height or width not equal to 1");
+
+  if (leftTransform) {
+    int64_t tileH = (inputH - (r - 1)) / m;
+    if (inputH != tileH * m + (r - 1))
+      return emitOpError("expect input height to be tileH * m + (r - 1)");
+    if (tileH != outputTileH)
+      return emitOpError("expect output tileH to be (inputH - (r - 1)) / m");
+    if (outputH != m + r - 1)
+      return emitOpError("expect output height to be m + r - 1");
+  }
+  if (rightTransform) {
+    int64_t tileW = (inputW - (r - 1)) / m;
+    if (inputW != tileW * m + (r - 1))
+      return emitOpError("expect input width to be tileW * m + (r - 1)");
+    if (tileW != outputTileW)
+      return emitOpError("expect output tileW to be (inputW - (r - 1)) / m");
+    if (outputW != m + r - 1)
+      return emitOpError("expect output width to be m + r - 1");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Checks value dims 0-3 against output H/W (dims 1, 2); assumes static dims.
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueTileH = valueShape[2];
+  int64_t valueTileW = valueShape[3];
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  bool leftTransform = valueH != 1;
+  bool rightTransform = valueW != 1;
+  if (!leftTransform && !rightTransform)
+    return emitOpError("expect value height or width not equal to 1");
+
+  if (leftTransform) {
+    if (valueH != m + r - 1)
+      return emitOpError("expect value height to be m + r - 1");
+    if (outputH != m * valueTileH)
+      return emitOpError("expect output height to be m * value tileH");
+  }
+
+  if (rightTransform) {
+    if (valueW != m + r - 1)
+      return emitOpError("expect value width to be m + r - 1");
+    if (outputW != m * valueTileW)
+      return emitOpError("expect output width to be m * value tileW");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..6b46f9e07abf8
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,334 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+/// Minimal filtering algorithms are written F(m, r): m is the number of
+/// outputs produced and r is the filter size. They consume alpha = m + r - 1
+/// inputs; e.g. in the 2-D case F(2, 3) reads a 4x4 input tile with a 3x3
+/// filter and produces a 2x2 output tile.
+///
+/// The configurations below are the ones this file currently supports.
+constexpr TransformMapKeyTy F_2_3 = std::make_pair(2, 3);
+constexpr TransformMapKeyTy F_4_3 = std::make_pair(4, 3);
+constexpr TransformMapKeyTy F_2_5 = std::make_pair(2, 5);
+
+/// Generate a linalg.batch_matmul that multiplies the transformed input with
+/// the transformed filter, reducing over the channel dimension C.
+///
+/// linalg.batch_matmul only accepts 3-D operands, so the two alpha dimensions
+/// are folded into the batch dimension and, for the input, the tile and batch
+/// dimensions are folded into the row dimension:
+///   filter: (alphaH, alphaW, C, F) -> (alphaH x alphaW, C, F)
+///   input:  (alphaH, alphaW, tileH, tileW, N, C)
+///             -> (alphaH x alphaW, tileH x tileW x N, C)
+///
+/// The generated IR is
+///
+///   %collapsed_input = tensor.collapse_shape %input
+///   %collapsed_filter = tensor.collapse_shape %filter
+///   %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+///   %expanded_ret = tensor.expand_shape %ret
+///
+/// The returned value has layout (alphaH, alphaW, tileH, tileW, N, F) and
+/// element type `outputElementType`. Both operands must have static shapes.
+static Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                            Value transformedFilter, Value transformedInput,
+                            Type outputElementType) {
+  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  assert(filterType.hasStaticShape() && "only support static shapes.");
+  assert(filterType.getRank() == 4 && "expected a rank-4 filter");
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  Type filterElementType = filterType.getElementType();
+  auto filterReassocType = RankedTensorType::get(
+      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
+      filterElementType);
+  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
+  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, filterReassocType, transformedFilter, filterReassoc);
+
+  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
+  // (alphaH x alphaW, tileH x tileW x N, C) for input.
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  assert(inputType.hasStaticShape() && "only support static shapes.");
+  assert(inputType.getRank() == 6 && "expected a rank-6 input");
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  Type inputElementType = inputType.getElementType();
+  auto inputReassocType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
+      inputElementType);
+  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, inputReassocType, transformedInput, inputReassoc);
+
+  // Batched matrix multiply, accumulating into `outputElementType`.
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
+      outputElementType);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                outputElementType);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
+  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
+  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  auto outputReassocType =
+      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+                             inputShape[3], inputShape[4], filterShape[3]},
+                            outputElementType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
+  return expandOutput;
+}
+
+/// Zero-pad `value` at the high end of every dimension so that the result has
+/// the static shape `alignedShape`, and return the tensor.pad result.
+///
+/// `alignedShape` must have the same rank as `value`, and each entry must be
+/// at least the corresponding size of `value` (the difference is used as the
+/// high padding amount).
+static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
+                                   Value value,
+                                   ArrayRef<int64_t> alignedShape) {
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  assert(valueShape.size() == alignedShape.size() &&
+         "aligned shape must have the same rank as the value");
+
+  // Nothing is added before the data; all padding goes after it.
+  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+  SmallVector<OpFoldResult, 6> lowIndices(alignedShape.size(), zeroIndex);
+  SmallVector<OpFoldResult, 6> highIndices;
+  highIndices.reserve(alignedShape.size());
+  for (size_t i = 0, e = alignedShape.size(); i < e; ++i)
+    highIndices.emplace_back(
+        rewriter.getIndexAttr(alignedShape[i] - valueShape[i]));
+
+  auto alignedType = RankedTensorType::get(alignedShape, elementType);
+  Value padValue = rewriter.create<arith::ConstantOp>(
+      loc, elementType, rewriter.getZeroAttr(elementType));
+  return rewriter.create<tensor::PadOp>(loc, alignedType, value, lowIndices,
+                                        highIndices, padValue);
+}
+
+/// Extract the leading sub-tensor of type `extractedType` from `value` using a
+/// tensor.extract_slice with zero offsets, unit strides and the sizes of
+/// `extractedType`.
+///
+/// `value` must have the same rank as `extractedType`, and `extractedType`
+/// must have a static shape. Any rank is supported: offsets, sizes and
+/// strides are all derived from `extractedType` so they always agree.
+static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+                                      Value value,
+                                      RankedTensorType extractedType) {
+  int64_t rank = extractedType.getRank();
+  SmallVector<OpFoldResult, 4> offsets(rank, rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult, 4> strides(rank, rewriter.getIndexAttr(1));
+
+  SmallVector<OpFoldResult, 4> sizes;
+  sizes.reserve(rank);
+  for (int64_t size : extractedType.getShape())
+    sizes.emplace_back(rewriter.getIndexAttr(size));
+
+  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+                                                 offsets, sizes, strides);
+}
+
+/// Return true when every integer stored in `attr` equals 1 (interpreted as a
+/// signed value). An attribute with no elements trivially satisfies this.
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  for (const APInt &value : attr) {
+    if (value.getSExtValue() != 1)
+      return false;
+  }
+  return true;
+}
+
+/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
+/// linalg.winograd_*_transform ops.
+///
+/// The conversion is applied only when all of the following hold (otherwise
+/// the mismatch is reported through notifyMatchFailure and the IR is left
+/// untouched):
+///  - the convolution has pure tensor semantics, since the rewrite builds
+///    tensor ops and replaces the convolution's result;
+///  - the input, filter and output have static shapes;
+///  - all dilations and strides are 1;
+///  - the filter is (r x r), (r x 1) or (1 x r);
+///  - (m, r) is (2, 3), (4, 3) or (2, 5).
+///
+/// The generated sequence is: filter transform, input transform (on a
+/// zero-padded input when the output is not a multiple of the tile size),
+/// batched matmul, and output transform (into a zero-padded output that is
+/// sliced back to the original shape when needed). Returns the op producing
+/// the value that replaces the convolution.
+static FailureOr<Operation *>
+winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
+                     int64_t m, int64_t r) {
+  // The rewrite creates tensor ops and replaces the op's single result, which
+  // a convolution on buffers does not have.
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected pure tensor semantics");
+
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  // Tile counts and padding amounts below are computed from the output
+  // sizes, so dynamic output dimensions must be rejected too.
+  if (!outputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the output");
+
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  if (!hasAllOneValues(convOp.getStrides()))
+    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputN = outputShape[0];
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t outputF = outputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
+  bool isSupportedFilter = (filterH == r && filterW == r) ||
+                           (filterH == r && filterW == 1) ||
+                           (filterH == 1 && filterW == r);
+  if (!isSupportedFilter)
+    return rewriter.notifyMatchFailure(
+        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5). Other
+  // configurations have no constant transformation matrices yet.
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+  TransformMapKeyTy key = {m, r};
+  if (!llvm::is_contained(validConfigs, key))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "unsupported (m, r) configuration");
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  int64_t heightM = leftTransform ? m : 1;
+  int64_t widthM = rightTransform ? m : 1;
+  int64_t heightR = leftTransform ? r : 1;
+  int64_t widthR = rightTransform ? r : 1;
+
+  // --- Create operation for filter transform ---
+  Type filterElementType = filterType.getElementType();
+  int64_t alphaH = heightM + heightR - 1;
+  int64_t alphaW = widthM + widthR - 1;
+  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
+                                       filterElementType);
+  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                                    filterElementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, m, r);
+
+  // --- Create operation for input transform ---
+
+  // When input size - (r - 1) is not aligned with output tile size, we need to
+  // pad the input data to create the full tiles as tiling.
+  Type inputElementType = inputType.getElementType();
+  int64_t alignedInputH = tileH * heightM + (heightR - 1);
+  int64_t alignedInputW = tileW * widthM + (widthR - 1);
+  if (alignedInputH != inputH || alignedInputW != inputW) {
+    input = insertToAlignedTensor(
+        rewriter, loc, input, {inputN, alignedInputH, alignedInputW, inputC});
+  }
+
+  retType = RankedTensorType::get(
+      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
+  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                              inputElementType);
+  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+      loc, retType, input, retValue, m, r);
+
+  Type outputElementType = outputType.getElementType();
+  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
+                                   transformedInput, outputElementType);
+
+  // --- Create operation for output transform ---
+
+  // When output size is not aligned with output tile size, we need to pad the
+  // output buffer to insert the full tiles after tiling.
+  int64_t alignedOutputH = tileH * heightM;
+  int64_t alignedOutputW = tileW * widthM;
+  bool isOutputUnaligned =
+      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
+  if (isOutputUnaligned) {
+    auto alignedOutputType = RankedTensorType::get(
+        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
+    output = insertToAlignedTensor(rewriter, loc, output,
+                                   alignedOutputType.getShape());
+    outputType = alignedOutputType;
+  }
+
+  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+      loc, outputType, matmulRet, output, m, r);
+
+  // When output size is not aligned with output tile size, extract the
+  // value from the padded buffer.
+  if (isOutputUnaligned) {
+    transformedOutput = extractFromAlignedTensor(
+        rewriter, loc, transformedOutput,
+        RankedTensorType::get({outputN, outputH, outputW, outputF},
+                              outputElementType));
+  }
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+/// A rewrite pattern for Winograd Conv2D algorithm.
+///
+/// Rewrites linalg.conv_2d_nhwc_fhwc into the winograd transform ops using
+/// the F(m, r) configuration supplied at construction; see
+/// winogradConv2DHelper for the conditions under which it applies.
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  // Only this constructor is provided: inheriting OpRewritePattern's
+  // constructors would allow building the pattern with `m` and `r`
+  // uninitialized.
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+      return failure();
+
+    return success();
+  }
+
+private:
+  /// Output tile size of F(m, r).
+  int64_t m;
+  /// Filter size of F(m, r).
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+/// Add the Winograd Conv2D rewrite for linalg.conv_2d_nhwc_fhwc, configured
+/// with F(m, r), to `patterns`.
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  patterns.add<WinogradConv2DNhwcFhwc>(patterns.getContext(), m, r);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..ec11a6ef8fbee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,193 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+
+// F(4 x 4, 3 x 3): a 6x6 input with a 3x3 filter yields a single 4x4 output
+// tile, so neither the input nor the output needs padding.
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// F(2 x 2, 5 x 5): a 6x6 input with a 5x5 filter yields a single 2x2 output
+// tile.
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%out : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %0 : tensor<2x2x2x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_2x2_5x5
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// F(1 x 4, 1 x 3): a 1x3 filter only needs the width (right-side) transform,
+// so the height extent of the transformed tensors stays 1.
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%out : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+  return %0 : tensor<2x1x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_1x4_1x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// F(4 x 1, 3 x 1): a 3x1 filter only needs the height (left-side) transform,
+// so the width extent of the transformed tensors stays 1.
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%out : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %0 : tensor<2x4x1x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_4x1_3x1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// 10x10 input, 8x8 output with F(4, 3): exactly 2x2 tiles of 4x4 outputs, so
+// no padding or slicing is generated.
+func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_aligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// 11x11 input, 9x9 output with F(4, 3): 9 is not a multiple of 4, so the input
+// is zero-padded to 14x14, the output buffer to 12x12, and the result is
+// sliced back to 9x9.
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %0 : tensor<2x9x9x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:  ^bb0
+// CHECK-NEXT:    tensor.yield %[[CST]] : f32
+// CHECK-NEXT:  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT:  %[[PADDED_1:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:  ^bb0
+// CHECK-NEXT:    tensor.yield %[[CST]] : f32
+// CHECK-NEXT:  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// f16 input and filter with an f32 output: the transforms keep the f16
+// element type, while the batch_matmul accumulates into f32.
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf16>, tensor<2x3x3x5xf16>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_type_promotion
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:   return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// A 3x2 filter is neither (r x r), (r x 1) nor (1 x r), so the convolution is
+// expected to be left unchanged.
+func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_1
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+// A 4x4 filter matches neither registered configuration (r = 3 or r = 5), so
+// the convolution is expected to be left unchanged.
+func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_2
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+// Dynamic input (and output) shapes are not supported, so the convolution is
+// expected to be left unchanged.
+func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_3
+// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..12cb46a5968f1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
+  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {

>From bb8087930cfd79a3d4ebf6a8e959f4c30bb70fcf Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:49:08 +0100
Subject: [PATCH 07/22] [mlir][linalg] Add transform operator for Winograd
 Conv2D algorithm

Add a transform operation, structured.winograd_conv2d, that converts
linalg.conv_2d_nhwc_fhwc into Linalg Winograd operations.
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 ++
 .../TransformOps/LinalgTransformOps.cpp       | 25 ++++++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  6 ++
 .../Linalg/transform-winograd-conv2d.mlir     | 88 +++++++++++++++++++
 5 files changed, 177 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..68d0f713caad4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle of the sequence that replaces the original
+    convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,
+                       I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..da107b66257a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of output and r is the dimension
+/// size of filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..d051b29e1f06f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return winogradConv2D(rewriter, op, getM(), getR());
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 6b46f9e07abf8..843db0c069813 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -324,6 +324,12 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..1e74fea5a1c31
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = tensor.empty() : tensor<2x9x9x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x9x9x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %2 : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x9x9x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:   %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
+// CHECK-NEXT:   %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+// CHECK-NEXT:   %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }

>From cc23f43cfab82f1c0b9ddbf6cacd29a20f99d825 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 12:26:15 +0100
Subject: [PATCH 08/22] Address ftynse's comments

---
 .../Linalg/TransformOps/LinalgTransformOps.td |   8 +-
 .../TransformOps/LinalgTransformOps.cpp       |  26 ++--
 .../Linalg/transform-winograd-conv2d.mlir     | 112 ++++++++----------
 3 files changed, 69 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 68d0f713caad4..5ef56bc97fef1 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2597,7 +2597,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
      TransformOpInterface, TransformEachOpTrait,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    Winograd Conv2D algorithm will convert linalg Conv2D operation into batched
     matrix multiply. Before the matrix multiply, it will convert filter and
     input into a format suitable for batched matrix multiply. After the matrix
     multiply, it will convert output to the final result tensor.
@@ -2612,9 +2612,9 @@ def WinogradConv2DOp : Op<Transform_Dialect,
 
     #### Return modes:
 
-    This operation fails if `target` is unsupported. Otherwise, the operation
-    succeeds and returns a handle of the sequence that replaces the original
-    convolution.
+    This operation produces a silenceable failure if `target` is unsupported.
+    Otherwise, the operation succeeds and returns a handle of the sequence that
+    replaces the original convolution.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d051b29e1f06f..e0f2d00400d63 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3489,17 +3489,23 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  auto maybeTransformed =
-      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
-            return winogradConv2D(rewriter, op, getM(), getR());
-          })
-          .Default([&](Operation *op) {
-            return rewriter.notifyMatchFailure(op, "not supported");
-          });
-
-  if (failed(maybeTransformed))
-    return emitDefaultSilenceableFailure(target);
+  FailureOr<Operation *> maybeTransformed = failure();
+  bool supported = TypeSwitch<Operation *, bool>(target)
+                       .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+                         maybeTransformed =
+                             winogradConv2D(rewriter, op, getM(), getR());
+                         return true;
+                       })
+                       .Default([](Operation *) { return false; });
+
+  // Only linalg.conv_2d_nhwc_fhwc is handled. Report anything else as a
+  // silenceable failure anchored at the payload op, and return before the
+  // empty `maybeTransformed` below could be dereferenced.
+  if (!supported)
+    return emitSilenceableFailure(target->getLoc(), "not supported");
+
+  if (failed(maybeTransformed))
+    return emitSilenceableError() << "apply Winograd Conv2D failed";

   results.push_back(*maybeTransformed);
   return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index 1e74fea5a1c31..0a2dcc035ebd3 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -1,13 +1,8 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s
 
-func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-  %0 = tensor.empty() : tensor<2x8x8x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x8x8x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-  return %2 : tensor<2x8x8x2xf32>
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
 }
 
 module attributes {transform.with_named_sequence} {
@@ -18,38 +13,17 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
+// CHECK: linalg.winograd_filter_transform m(4) r(3)
+// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: linalg.batch_matmul
+// CHECK: linalg.winograd_output_transform m(4) r(3)
 
 // -----
 
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-  %0 = tensor.empty() : tensor<2x9x9x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x9x9x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
-  return %2 : tensor<2x9x9x2xf32>
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %0 : tensor<2x9x9x2xf32>
 }
 
 module attributes {transform.with_named_sequence} {
@@ -60,29 +34,43 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:   %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT:   %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT:   %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }
+// CHECK:       linalg.winograd_filter_transform m(4) r(3)
+// CHECK:       tensor.pad
+// CHECK-SAME:  low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:       linalg.winograd_input_transform m(4) r(3)
+// CHECK:       tensor.pad
+// CHECK-SAME:  low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:       linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  // expected-error @+1 {{not supported}}
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
+  return %0 : tensor<2x?x?x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{apply Winograd Conv2D failed}}
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}

>From 48e24b4f7798be53b46a49511cbc524c2b2162a4 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 15:13:28 +0100
Subject: [PATCH 09/22] Revert "[mlir][linalg] Decompose winograd operators"

This reverts commit 24c4f957ae673c2955fc0674f91e488813d59350.
---
 .../Dialect/Linalg/Transforms/Transforms.h    |   3 -
 .../Linalg/Transforms/WinogradConv2D.cpp      | 773 ------------------
 .../Linalg/winograd-conv2d-rewrite.mlir       | 105 ---
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  11 -
 4 files changed, 892 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bb7ec590faad0..da107b66257a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1703,9 +1703,6 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r);
 
-/// Patterns to decompose Winograd operators.
-void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
-
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d245723c85646..d1f4be8bbf29a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -12,10 +12,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -26,156 +23,6 @@ namespace linalg {
 
 namespace {
 
-// clang-format off
-// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
-// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
-// m is the output dimension and r is the filter dimension, is
-//
-// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
-//
-// g is filter and d is input data. We need to prepare 6 constant
-// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
-//
-// The following tables define these constant transformation matrices for
-// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
-constexpr float G_2x2_3x3[] = {
-   -1,     0,   0,
- 1./2, -1./2, 1./2,
- 1./2,  1./2, 1./2,
-    0,     0,    1
-};
-
-constexpr float GT_2x2_3x3[] = {
-   -1,  1./2, 1./2, 0,
-    0, -1./2, 1./2, 0,
-    0,  1./2, 1./2, 1
-};
-
-constexpr float BT_2x2_3x3[] = {
-   -1,    0,   1,   0,
-    0,   -1,   1,   0,
-    0,    1,   1,   0,
-    0,   -1,   0,   1
-};
-
-constexpr float B_2x2_3x3[] = {
-   -1,    0,   0,   0,
-    0,   -1,   1,  -1,
-    1,    1,   1,   0,
-    0,    0,   0,   1
-};
-
-constexpr float AT_2x2_3x3[] = {
-    1,    1,   1,   0,
-    0,   -1,   1,   1
-};
-
-constexpr float A_2x2_3x3[] = {
-    1,    0,
-    1,   -1,
-    1,    1,
-    0,    1
-};
-
-constexpr float G_4x4_3x3[] = {
-     1,     0,     0,
- -1./3,  1./3, -1./3,
- -1./3, -1./3, -1./3,
- 1./12, -1./6,  1./3,
- 1./12,  1./6,  1./3,
-     0,     0,     1
-};
-
-constexpr float GT_4x4_3x3[] = {
- 1,  -1./3, -1./3, 1./12, 1./12, 0,
- 0,   1./3, -1./3, -1./6,  1./6, 0,
- 0,  -1./3, -1./3,  1./3,  1./3, 1
-};
-
-constexpr float BT_4x4_3x3[] = {
- 1./4,     0, -5./16,      0, 1./16,     0,
-    0,  1./4,  -1./4, -1./16, 1./16,     0,
-    0, -1./4,  -1./4,  1./16, 1./16,     0,
-    0,  1./4,  -1./8,  -1./4,  1./8,     0,
-    0, -1./4,  -1./8,   1./4,  1./8,     0,
-    0,  1./4,      0, -5./16,     0, 1./16
-};
-
-constexpr float B_4x4_3x3[] = {
-   1./4,      0,     0,     0,     0,      0,
-      0,   1./4, -1./4,  1./4, -1./4,   1./4,
- -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
-      0, -1./16, 1./16, -1./4,  1./4, -5./16,
-  1./16,  1./16, 1./16,  1./8,  1./8,      0,
-      0,      0,     0,     0,     0,  1./16
-};
-
-constexpr float AT_4x4_3x3[] = {
- 1./8,  1./4, 1./4,  1./8, 1./8,    0,
-    0, -1./4, 1./4, -1./4, 1./4,    0,
-    0,  1./4, 1./4,  1./2, 1./2,    0,
-    0, -1./4, 1./4,    -1,    1, 1./2
-};
-
-constexpr float A_4x4_3x3[] = {
-  1./8,     0,    0,     0,
-  1./4, -1./4, 1./4, -1./4,
-  1./4,  1./4, 1./4,  1./4,
-  1./8, -1./4, 1./2,    -1,
-  1./8,  1./4, 1./2,     1,
-     0,     0,    0,  1./2
-};
-
-constexpr float G_2x2_5x5[] = {
-     1,     0,      0,      0,      0,
-  1./6, -1./6,   1./6,  -1./6,   1./6,
- -1./6, -1./6,  -1./6,  -1./6,  -1./6,
--4./15, 2./15, -1./15,  1./30, -1./60,
- 1./60, 1./30,  1./15,  2./15,  4./15,
-     0,     0,      0,      0,      1
-};
-
-constexpr float GT_2x2_5x5[] = {
-   1,  1./6, -1./6, -4./15, 1./60, 0,
-   0, -1./6, -1./6,  2./15, 1./30, 0,
-   0,  1./6, -1./6, -1./15, 1./15, 0,
-   0, -1./6, -1./6,  1./30, 2./15, 0,
-   0,  1./6, -1./6, -1./60, 4./15, 1
-};
-
-constexpr float BT_2x2_5x5[] = {
- 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
-    0,   1./8,  1./16,  -5./16,   1./8,    0,
-    0,  -1./8, -5./16,  -1./16,   1./8,    0,
-    0,   1./4,  -1./8,   -1./4,   1./8,    0,
-    0,  -1./8,  -1./4,    1./8,   1./4,    0,
-    0,   1./8,  3./16,   -1./4, -3./16, 1./8
-};
-
-constexpr float B_2x2_5x5[] = {
-   1./8,      0,      0,     0,     0,      0,
-  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
-  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
- -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
-   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
-      0,      0,      0,     0,     0,   1./8
-};
-
-constexpr float AT_2x2_5x5[] = {
-  1./2,  1, 1,  2, 1,    0,
-     0, -1, 1, -1, 2, 1./2
-};
-
-constexpr float A_2x2_5x5[] = {
- 1./2,    0,
-    1,   -1,
-    1,    1,
-    2,   -1,
-    1,    2,
-    0, 1./2
-};
-// clang-format on
-
 using TransformMapKeyTy = std::pair<int, int>;
 
 // We use F(m, r) to define the size of minimal filtering algorithms.
@@ -189,92 +36,6 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
-struct TransformMatrix {
-  TransformMatrix(const float *table, int64_t rows, int64_t cols,
-                  int64_t scalarFactor = 1)
-      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
-
-  const float *table;
-  int64_t rows;
-  int64_t cols;
-  int64_t scalarFactor;
-};
-
-Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
-                              TransformMatrix transform, Type type) {
-  ArrayRef<float> const_vec(transform.table, transform.rows * transform.cols);
-
-  return rewriter.create<arith::ConstantOp>(
-      loc, DenseFPElementsAttr::get(
-               RankedTensorType::get(
-                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
-               const_vec));
-}
-
-Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
-                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
-                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
-                    int64_t srcSize) {
-  auto sourceType = cast<ShapedType>(source.getType());
-  Type elementType = sourceType.getElementType();
-  auto sourceShape = sourceType.getShape();
-  int64_t height = sourceShape[heightIdx];
-  int64_t width = sourceShape[widthIdx];
-
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult, 6> offsets(srcSize, zeroIndex);
-  offsets[outLoopIdx] = outLoopIndex;
-  offsets[inLoopIdx] = inLoopIndex;
-  SmallVector<OpFoldResult, 6> sizes(srcSize, oneIndex);
-  sizes[heightIdx] = rewriter.getIndexAttr(height);
-  sizes[widthIdx] = rewriter.getIndexAttr(width);
-  SmallVector<OpFoldResult, 6> strides(srcSize, oneIndex);
-  SmallVector<int64_t> targetShape(srcSize, 1);
-  targetShape[heightIdx] = height;
-  targetShape[widthIdx] = width;
-
-  auto targetType = RankedTensorType::get(targetShape, elementType);
-  auto extractFilterOp = rewriter.create<tensor::ExtractSliceOp>(
-      loc, targetType, source, offsets, sizes, strides);
-
-  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
-  auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, extractFilterOp, extractFilterType);
-
-  return extractFilter;
-}
-
-Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
-                   Value dest, Value outLoopIndex, Value inLoopIndex,
-                   int64_t height, int64_t width, int64_t outLoopIdx,
-                   int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
-                   int64_t destSize) {
-  auto sourceType = cast<ShapedType>(source.getType());
-  Type elementType = sourceType.getElementType();
-  SmallVector<int64_t> sliceShape(destSize, 1);
-  sliceShape[heightIdx] = height;
-  sliceShape[widthIdx] = width;
-  auto init = rewriter.create<tensor::EmptyOp>(loc, sliceShape, elementType);
-  auto result = tensor::createCanonicalRankReducingInsertSliceOp(rewriter, loc,
-                                                                 source, init);
-
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult, 6> retOffsets(destSize, zeroIndex);
-  retOffsets[outLoopIdx] = outLoopIndex;
-  retOffsets[inLoopIdx] = inLoopIndex;
-  SmallVector<OpFoldResult, 6> retSizes(destSize, oneIndex);
-  retSizes[heightIdx] = rewriter.getIndexAttr(height);
-  retSizes[widthIdx] = rewriter.getIndexAttr(width);
-  SmallVector<OpFoldResult, 6> strides(destSize, oneIndex);
-
-  auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-      loc, result, dest, retOffsets, retSizes, strides);
-
-  return insertSliceOp;
-}
-
 Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
   auto type = cast<ShapedType>(data.getType());
   auto elementType = type.getElementType();
@@ -287,261 +48,6 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
                                                   reassociation);
 }
 
-// This function transforms the filter. The data layout of the filter is FHWC.
-// The transformation matrix is 2-dimension. We need to extract H x W from
-// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
-// After the transformation, we get
-//
-// scf.for %f = lo_f to hi_f step 1
-//   scf.for %c = lo_c to hi_c step 1
-//     %extracted = extract filter<h x w> from filter<f x h x w x c>
-//     %ret = linalg.matmul G, %extracted
-//     %ret = linalg.matmul %ret, GT
-//     %inserted = insert %ret into filter<tile_h x tile_w x h x w x c x f>
-//
-Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
-                      Value retValue, int64_t m, int64_t r,
-                      bool leftTransform = true, bool rightTransform = true) {
-  // Map from (m, r) to G transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
-      GMatrices = {
-          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
-          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
-          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
-      };
-
-  // Map from (m, r) to GT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
-      GTMatrices = {
-          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
-          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
-          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
-      };
-
-  auto filterType = cast<ShapedType>(filter.getType());
-  Type elementType = filterType.getElementType();
-  auto filterShape = filterType.getShape(); // F, H, W, C
-  int64_t filterF = filterShape[0];
-  int64_t filterH = filterShape[1];
-  int64_t filterW = filterShape[2];
-  int64_t filterC = filterShape[3];
-
-  if (filterH != r && filterH != 1)
-    return Value();
-  if (filterW != r && filterW != 1)
-    return Value();
-
-  // Return shape is <H x W x C x F>
-  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
-  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
-  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-  auto outerForOp =
-      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
-  Block *outerForBody = outerForOp.getBody();
-  rewriter.setInsertionPointToStart(outerForBody);
-  Value FIter = outerForBody->getArgument(0);
-
-  auto innerForOp = rewriter.create<scf::ForOp>(
-      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
-  Block *innerForBody = innerForOp.getBody();
-  rewriter.setInsertionPointToStart(innerForBody);
-  Value CIter = innerForBody->getArgument(0);
-
-  // Extract (H, W) from (F, H, W, C)
-  auto extractFilter = extract2DData(
-      rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
-      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
-
-  TransformMapKeyTy key = {m, r};
-  int64_t retRows = 1;
-  Value matmulRetValue = extractFilter;
-  if (leftTransform) {
-    // Get constant transform matrix G
-    auto it = GMatrices.find(key);
-    if (it == GMatrices.end())
-      return Value();
-    const TransformMatrix &GMatrix = it->second;
-
-    retRows = GMatrix.rows;
-    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType);
-    // Multiply G x g
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  if (rightTransform) {
-    // Get constant transform matrix GT
-    auto it = GTMatrices.find(key);
-    if (it == GTMatrices.end())
-      return Value();
-    const TransformMatrix &GTMatrix = it->second;
-
-    auto matmulType =
-        RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value GT = create2DTransformMatrix(rewriter, loc, GTMatrix, elementType);
-    // Multiply u = (G x g) x GT
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  // Insert (H, W) to (1, 1, H, W, C, F)
-  Value iterArg = innerForOp.getRegionIterArgs()[0];
-  int64_t retHeight = leftTransform ? m + r - 1 : 1;
-  int64_t retWidth = rightTransform ? m + r - 1 : 1;
-  auto insertSliceOp = insert2DData(
-      rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth,
-      /*outLoopIdx=*/5, /*inLoopIdx=*/4, /*heightIdx=*/2, /*widthIdx=*/3,
-      /*destSize=*/6);
-
-  rewriter.create<scf::YieldOp>(loc, insertSliceOp);
-
-  rewriter.setInsertionPointToEnd(outerForBody);
-  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
-
-  rewriter.setInsertionPointAfter(outerForOp);
-
-  return outerForOp.getResult(0);
-}
-
-// This function transforms the input. The data layout of the input is NHWC.
-// The transformation matrix is 2-dimension. We need to extract H x W from
-// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
-// After the transformation, we get
-//
-// scf.for %n = lo_n to hi_n step 1
-//   scf.for %c = lo_c to hi_c step 1
-//     %extracted = extract input<h x w> from input<n x h x w x c>
-//     %ret = linalg.matmul BT, %extracted
-//     %ret = linalg.matmul %ret, B
-//     %inserted = insert %ret into input<h x w x n x c>
-//
-Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
-                     Value retValue, int64_t m, int64_t r,
-                     bool leftTransform = true, bool rightTransform = true) {
-  // Map from (m, r) to BT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
-      BTMatrices = {
-          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
-      };
-
-  // Map from (m, r) to B transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
-      BMatrices = {
-          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
-      };
-
-  auto inputType = cast<ShapedType>(input.getType());
-  Type elementType = inputType.getElementType();
-  auto inputShape = inputType.getShape(); // N, H, W, C
-  int64_t inputN = inputShape[0];
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
-  int64_t inputC = inputShape[3];
-  int64_t alphaH = leftTransform ? m + r - 1 : 1;
-  int64_t alphaW = rightTransform ? m + r - 1 : 1;
-
-  if (inputH != alphaH && inputH != 1)
-    return Value();
-  if (inputW != alphaW && inputW != 1)
-    return Value();
-
-  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
-  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
-  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-
-  auto outerForOp =
-      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
-  Block *outerForBody = outerForOp.getBody();
-  rewriter.setInsertionPointToStart(outerForBody);
-  Value NIter = outerForBody->getArgument(0);
-
-  auto innerForOp = rewriter.create<scf::ForOp>(
-      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
-  Block *innerForBody = innerForOp.getBody();
-  rewriter.setInsertionPointToStart(innerForBody);
-  Value CIter = innerForBody->getArgument(0);
-
-  // Extract (H, W) from (N, H, W, C)
-  auto extractInput = extract2DData(
-      rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0,
-      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
-
-  TransformMapKeyTy key = {m, r};
-  int64_t retRows = 1;
-  int64_t retCols = 1;
-  Value matmulRetValue = extractInput;
-  if (leftTransform) {
-    // Get constant transform matrix BT
-    auto it = BTMatrices.find(key);
-    if (it == BTMatrices.end())
-      return Value();
-    const TransformMatrix &BTMatrix = it->second;
-
-    retRows = BTMatrix.rows;
-    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value BT =
-        create2DTransformMatrix(rewriter, loc, BTMatrix, rewriter.getF32Type());
-    // Multiply BT x d
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  if (rightTransform) {
-    // Get constant transform matrix B
-    auto it = BMatrices.find(key);
-    if (it == BMatrices.end())
-      return Value();
-    const TransformMatrix &BMatrix = it->second;
-
-    retCols = BMatrix.cols;
-    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-    Value B =
-        create2DTransformMatrix(rewriter, loc, BMatrix, rewriter.getF32Type());
-    // Multiply v = (BT x d) x B
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  // Insert v
-  // Insert (H, W) to (1, 1, H, W, N, C)
-  Value iterArg = innerForOp.getRegionIterArgs()[0];
-  auto combinedVal = insert2DData(
-      rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols,
-      /*outLoopIdx=*/4, /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3,
-      /*destSize=*/6);
-
-  rewriter.create<scf::YieldOp>(loc, combinedVal);
-
-  rewriter.setInsertionPointToEnd(outerForBody);
-  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
-
-  rewriter.setInsertionPointAfter(outerForOp);
-
-  return outerForOp.getResult(0);
-}
-
 // This function generates linalg.batch_matmul to multiply input with filter.
 // linalg.batch_matmul only supports 3-dimension data sets. We can treat
 // tileH x tileW x H x W data as the 1-dimension data array. That is to convert
@@ -594,161 +100,6 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
-// This function transforms the output. The data layout of the output is HWNF.
-// The transformation matrix is 2-dimension. We need to extract H x W from
-// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
-// After the transformation, we get
-//
-// scf.for %n = lo_n to hi_n step 1
-//   scf.for %f = lo_f to hi_f step 1
-//     %extracted = extract input<h x w> from result<h x w x n x f>
-//     %ret = linalg.matmul AT, %extracted
-//     %ret = linalg.matmul %ret, A
-//     %inserted = insert %ret into ret<n x h x w x f>
-//
-Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
-                      Value output, int64_t m, int64_t r,
-                      bool leftTransform = true, bool rightTransform = true) {
-  // Map from (m, r) to AT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
-      ATMatrices = {
-          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
-          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
-          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
-      };
-
-  // Map from (m, r) to A transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
-      AMatrices = {
-          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
-          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
-          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
-      };
-
-  auto valueType = cast<ShapedType>(value.getType());
-  Type elementType = valueType.getElementType();
-  auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F
-  int64_t valueH = valueShape[2];
-  int64_t valueW = valueShape[3];
-  int64_t valueN = valueShape[4];
-  int64_t valueF = valueShape[5];
-  int64_t alphaH = leftTransform ? m + r - 1 : 1;
-  int64_t alphaW = rightTransform ? m + r - 1 : 1;
-
-  if (valueH != alphaH && valueH != 1)
-    return Value();
-  if (valueW != alphaW && valueW != 1)
-    return Value();
-
-  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
-  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
-  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-
-  auto outerForOp =
-      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
-  Block *outerForBody = outerForOp.getBody();
-  rewriter.setInsertionPointToStart(outerForBody);
-  Value NIter = outerForBody->getArgument(0);
-
-  auto innerForOp = rewriter.create<scf::ForOp>(
-      loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
-  Block *innerForBody = innerForOp.getBody();
-  rewriter.setInsertionPointToStart(innerForBody);
-  Value FIter = innerForBody->getArgument(0);
-
-  // Extract (H, W) from (1, 1, H, W, N, F)
-  auto extractValue = extract2DData(
-      rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4,
-      /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6);
-
-  TransformMapKeyTy key = {m, r};
-  int64_t retRows = 1;
-  int64_t retCols = 1;
-  int64_t leftScalarFactor = 1;
-  int64_t rightScalarFactor = 1;
-  Value matmulRetValue = extractValue;
-  if (leftTransform) {
-    // Get constant transform matrix AT
-    auto it = ATMatrices.find(key);
-    if (it == ATMatrices.end())
-      return Value();
-    const TransformMatrix &ATMatrix = it->second;
-
-    leftScalarFactor = ATMatrix.scalarFactor;
-    retRows = ATMatrix.rows;
-    auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType);
-    // Multiply AT x m
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  if (rightTransform) {
-    // Get constant transform matrix T
-    auto it = AMatrices.find(key);
-    if (it == AMatrices.end())
-      return Value();
-    const TransformMatrix &AMatrix = it->second;
-
-    rightScalarFactor = AMatrix.scalarFactor;
-    auto matmulType =
-        RankedTensorType::get({retRows, AMatrix.cols}, elementType);
-    retCols = AMatrix.cols;
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType);
-    // Multiply y = (AT x m) x A
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  // Multiply scalar factor.
-  Value scalarFactor = rewriter.create<arith::ConstantOp>(
-      loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
-  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-  auto init =
-      rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
-
-  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
-  SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
-                                       identityAffineMap, identityAffineMap};
-  auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
-      loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
-      ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
-      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        Value scalarVal = args[0];
-        Value matrixVal = args[1];
-        Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
-                                                           matrixVal);
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
-      });
-
-  // Insert slice y
-  // Insert (H, W) to (N, H, W, F)
-  Value iterArg = innerForOp.getRegionIterArgs()[0];
-  Value combinedVal = insert2DData(rewriter, loc, scalarMatrixOp.getResult(0),
-                                   iterArg, NIter, FIter, retRows, retCols,
-                                   /*outLoopIdx=*/0,
-                                   /*inLoopIdx=*/3, /*heightIdx=*/1,
-                                   /*widthIdx=*/2, /*destSize=*/4);
-
-  rewriter.create<scf::YieldOp>(loc, combinedVal);
-
-  rewriter.setInsertionPointToEnd(outerForBody);
-  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
-
-  rewriter.setInsertionPointAfter(outerForOp);
-
-  return outerForOp.getResult(0);
-}
-
 Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
                             RankedTensorType alignedType) {
   Value alignedInput = rewriter.create<tensor::EmptyOp>(
@@ -938,123 +289,6 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   return transformedOutput.getDefiningOp();
 }
 
-FailureOr<Operation *>
-decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
-                                       linalg::WinogradFilterTransformOp op) {
-  Location loc = op.getLoc();
-  Value filter = op.getFilter();
-  auto filterType = cast<ShapedType>(filter.getType());
-  auto filterShape = filterType.getShape();
-  int64_t filterH = filterShape[1];
-  int64_t filterW = filterShape[2];
-
-  // For F(m x 1, r x 1), we only need to do left side transform.
-  bool leftTransform = filterH != 1;
-  // For F(1 x m, 1 x r), we only need to do right side transform.
-  bool rightTransform = filterW != 1;
-  Value transformedFilter =
-      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
-                      op.getR(), leftTransform, rightTransform);
-  if (!transformedFilter)
-    return failure();
-
-  rewriter.replaceOp(op, transformedFilter);
-
-  return transformedFilter.getDefiningOp();
-}
-
-FailureOr<Operation *>
-decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
-                                      linalg::WinogradInputTransformOp op) {
-  Location loc = op.getLoc();
-  Value input = op.getInput();
-  auto inputType = cast<ShapedType>(input.getType());
-  auto inputShape = inputType.getShape();
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
-
-  // For F(m x 1, r x 1), we only need to do left side transform.
-  bool leftTransform = inputH != 1;
-  // For F(1 x m, 1 x r), we only need to do right side transform.
-  bool rightTransform = inputW != 1;
-  Value transformedInput =
-      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
-                     op.getR(), leftTransform, rightTransform);
-  if (!transformedInput)
-    return failure();
-
-  rewriter.replaceOp(op, transformedInput);
-
-  return transformedInput.getDefiningOp();
-}
-
-FailureOr<Operation *>
-decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
-                                       linalg::WinogradOutputTransformOp op) {
-  Location loc = op.getLoc();
-  Value value = op.getValue();
-  auto valueType = cast<ShapedType>(value.getType());
-  auto valueShape = valueType.getShape();
-  int64_t valueH = valueShape[2];
-  int64_t valueW = valueShape[3];
-
-  // For F(m x 1, r x 1), we only need to do left side transform.
-  bool leftTransform = valueH != 1;
-  // For F(1 x m, 1 x r), we only need to do right side transform.
-  bool rightTransform = valueW != 1;
-  Value transformedOutput =
-      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
-                      op.getR(), leftTransform, rightTransform);
-  if (!transformedOutput)
-    return failure();
-
-  rewriter.replaceOp(op, transformedOutput);
-
-  return transformedOutput.getDefiningOp();
-}
-
-class DecomposeWinogradFilterTransform final
-    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
-                                PatternRewriter &rewriter) const override {
-    if (failed(decomposeWinogradFilterTransformHelper(rewriter, op)))
-      return failure();
-
-    return success();
-  }
-};
-
-class DecomposeWinogradInputTransform final
-    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
-                                PatternRewriter &rewriter) const override {
-    if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
-      return failure();
-
-    return success();
-  }
-};
-
-class DecomposeWinogradOutputTransform final
-    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
-                                PatternRewriter &rewriter) const override {
-    if (failed(decomposeWinogradOutputTransformHelper(rewriter, op)))
-      return failure();
-
-    return success();
-  }
-};
-
 class WinogradConv2DNhwcFhwc final
     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
 public:
@@ -1089,12 +323,5 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
   patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
 }
 
-void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
-  MLIRContext *context = patterns.getContext();
-  patterns.insert<DecomposeWinogradFilterTransform>(context);
-  patterns.insert<DecomposeWinogradInputTransform>(context);
-  patterns.insert<DecomposeWinogradOutputTransform>(context);
-}
-
 } // end namespace linalg
 } // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
deleted file mode 100644
index 917d089c1981c..0000000000000
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ /dev/null
@@ -1,105 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s
-
-#map = affine_map<(d0, d1, d2, d3) -> (0)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-  %4 = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-  %5 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%4 : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-  %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-  %collapsed_0 = tensor.collapse_shape %5 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-  %6 = tensor.empty() : tensor<36x2x2xf32>
-  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-  %expanded = tensor.expand_shape %7 [[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<1x1x6x6x2x2xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %8 : tensor<2x4x4x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func.func @conv2d_4x4_3x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
-// CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
-// CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
-// CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
-// CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
-// CHECK-DAG:   %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
-// CHECK-DAG:   %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<1x1x6x6x5x2xf32>) {
-// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x5x2xf32>) {
-// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
-// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
-// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
-// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG5]], %[[ARG3]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  }
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S4]]) -> (tensor<1x1x6x6x2x5xf32>) {
-// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x2x5xf32>) {
-// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
-// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
-// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  }
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32>
-// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
-// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:      %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) {
-// CHECK-NEXT:      ^bb0(%[[IN:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:        %[[S17:.*]] = arith.mulf %[[IN]], %[[IN_9]] : f32
-// CHECK-NEXT:        linalg.yield %[[S17]] : f32
-// CHECK-NEXT:      } -> tensor<4x4xf32>
-// CHECK-NEXT:      %[[S16:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
-// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S9]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT:  }
-// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT:}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 5899f56da7345..12cb46a5968f1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -127,9 +127,6 @@ struct TestLinalgTransforms
       *this, "test-winograd-conv2d",
       llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
       llvm::cl::init(false)};
-  Option<bool> testDecomposeWinogradOps{
-      *this, "test-decompose-winograd-ops",
-      llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
 };
 } // namespace
 
@@ -221,12 +218,6 @@ static void applyWinogradConv2D(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
-  RewritePatternSet patterns(funcOp.getContext());
-  populateDecomposeWinogradOpsPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -253,8 +244,6 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnnecessaryInputs(getOperation());
   if (testWinogradConv2D)
     return applyWinogradConv2D(getOperation());
-  if (testDecomposeWinogradOps)
-    return applyDecomposeWinogradOps(getOperation());
 }
 
 namespace mlir {

>From afcddc2030bea39dc983a56d6351b14de935ada1 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 15:13:43 +0100
Subject: [PATCH 10/22] Revert "[mlir][linalg] Add transform operator for
 Winograd Conv2D algorithm"

This reverts commit 374b0d5b83ce080bea690199380e270a36ad1c52.
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 51 -----------
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 --
 .../TransformOps/LinalgTransformOps.cpp       | 25 ------
 .../Linalg/Transforms/WinogradConv2D.cpp      |  6 --
 .../Linalg/transform-winograd-conv2d.mlir     | 88 -------------------
 5 files changed, 177 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 68d0f713caad4..93e2c2db729da 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,55 +2587,4 @@ def MapCopyToThreadsOp :
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// Winograd Conv2D
-//===----------------------------------------------------------------------===//
-
-def WinogradConv2DOp : Op<Transform_Dialect,
-    "structured.winograd_conv2d",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
-  let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
-    matrix multiply. Before the matrix multiply, it will convert filter and
-    input into a format suitable for batched matrix multiply. After the matrix
-    multiply, it will convert output to the final result tensor.
-
-    The algorithm F(m x m, r x r) is
-
-    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
-    The size of output Y is m x m. The size of filter g is r x r. The size of
-    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-    transformation matrices.
-
-    #### Return modes:
-
-    This operation fails if `target` is unsupported. Otherwise, the operation
-    succeeds and returns a handle of the sequence that replaces the original
-    convolution.
-  }];
-
-  let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$m,
-                       I64Attr:$r);
-  let results = (outs TransformHandleTypeInterface:$transformed);
-
-  let assemblyFormat =
-    "$target attr-dict `:` functional-type($target, results)";
-
-  let builders = [
-    OpBuilder<(ins "Value":$target)>
-  ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index da107b66257a5..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,13 +1312,6 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
-/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
-/// F(m x m, r x r). m is the dimension size of output and r is the dimension
-/// size of filter.
-FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r);
-
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d051b29e1f06f..bc02788f9c441 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,31 +3480,6 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
-//===----------------------------------------------------------------------===//
-// WinogradConv2DOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
-    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  rewriter.setInsertionPoint(target);
-  auto maybeTransformed =
-      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
-            return winogradConv2D(rewriter, op, getM(), getR());
-          })
-          .Default([&](Operation *op) {
-            return rewriter.notifyMatchFailure(op, "not supported");
-          });
-
-  if (failed(maybeTransformed))
-    return emitDefaultSilenceableFailure(target);
-
-  results.push_back(*maybeTransformed);
-  return DiagnosedSilenceableFailure::success();
-}
-
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d1f4be8bbf29a..86e834d51f2fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -311,12 +311,6 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
-FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r) {
-  return winogradConv2DHelper(rewriter, op, m, r);
-}
-
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
deleted file mode 100644
index 1e74fea5a1c31..0000000000000
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ /dev/null
@@ -1,88 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
-
-func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-  %0 = tensor.empty() : tensor<2x8x8x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x8x8x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-  return %2 : tensor<2x8x8x2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
-    transform.yield
-  }
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-  %0 = tensor.empty() : tensor<2x9x9x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x9x9x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
-  return %2 : tensor<2x9x9x2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
-    transform.yield
-  }
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:   %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT:   %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT:   %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }

>From 67a5701bf19199c466e1ddd212ed6407e1ebecb0 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 15:14:03 +0100
Subject: [PATCH 11/22] Revert "[mlir][linalg] Implement Conv2D using Winograd
 Conv2D algorithm"

This reverts commit 4240341b4f06f1b77f63b0f619cae3804d88eb68.
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       | 114 -------
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 -
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  78 -----
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 -
 .../Linalg/Transforms/WinogradConv2D.cpp      | 321 ------------------
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 248 --------------
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  13 -
 7 files changed, 779 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 delete mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index de1097b6ac27b..64c538367267d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,118 +154,4 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
-  let summary = "Winograd filter transform operator";
-  let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
-    matrix multiply. Before the matrix multiply, it will convert filter and
-    input into a format suitable for batched matrix multiply. After the matrix
-    multiply, it will convert output to the final result tensor.
-
-    The algorithm F(m x m, r x r) is
-
-    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
-    The size of output Y is m x m. The size of filter g is r x r. The size of
-    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-    transformation matrices.
-
-    This operator is defined to represent the high level concept of filter
-    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
-  }];
-
-  let arguments = (ins AnyRankedTensor:$filter,
-                       AnyRankedTensor:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
-  );
-
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
-    `ins` `(` $filter `:` type($filter) `)`
-    `outs` `(` $output `:` type($output) `)`
-    `->` type($result)
-  }];
-  let hasVerifier = 1;
-}
-
-def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
-  let summary = "Winograd input transform operator";
-  let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
-    matrix multiply. Before the matrix multiply, it will convert filter and
-    input into a format suitable for batched matrix multiply. After the matrix
-    multiply, it will convert output to the final result tensor.
-
-    The algorithm F(m x m, r x r) is
-
-    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
-    The size of output Y is m x m. The size of filter g is r x r. The size of
-    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-    transformation matrices.
-
-    This operator is defined to represent the high level concept of input
-    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
-  }];
-
-  let arguments = (ins AnyRankedTensor:$input,
-                       AnyRankedTensor:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
-  );
-
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
-    `ins` `(` $input `:` type($input) `)`
-    `outs` `(` $output `:` type($output) `)`
-    `->` type($result)
-  }];
-  let hasVerifier = 1;
-}
-
-def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
-  let summary = "Winograd output transform operator";
-  let description = [{
-    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
-    matrix multiply. Before the matrix multiply, it will convert filter and
-    input into a format suitable for batched matrix multiply. After the matrix
-    multiply, it will convert output to the final result tensor.
-
-    The algorithm F(m x m, r x r) is
-
-    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
-    The size of output Y is m x m. The size of filter g is r x r. The size of
-    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-    transformation matrices.
-
-    This operator is defined to represent the high level concept of output
-    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
-  }];
-
-  let arguments = (ins AnyRankedTensor:$value,
-                       AnyRankedTensor:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
-  );
-
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
-    `ins` `(` $value `:` type($value) `)`
-    `outs` `(` $output `:` type($output) `)`
-    `->` type($result)
-  }];
-  let hasVerifier = 1;
-}
-
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..05e97befdec1f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,10 +1692,6 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
-/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r);
-
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7bf2a5bca037f..57d126603ebd7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,84 +2734,6 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
-//===----------------------------------------------------------------------===//
-// WinogradFilterTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradFilterTransformOp::verify() {
-  auto filterType = cast<ShapedType>(getFilter().getType());
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  auto filterElemType = filterType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (filterElemType != outputElemType) {
-    return emitOpError() << "expected element type of input " << filterElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
-
-  unsigned filterRank = filterType.getRank();
-  if (filterRank != 4)
-    return emitOpError() << "expected rank of input is 4";
-
-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 6)
-    return emitOpError() << "expected rank of output is 6";
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// WinogradInputTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradInputTransformOp::verify() {
-  auto inputType = cast<ShapedType>(getInput().getType());
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  auto inputElemType = inputType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (inputElemType != outputElemType) {
-    return emitOpError() << "expected element type of input " << inputElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
-
-  unsigned inputRank = inputType.getRank();
-  if (inputRank != 4)
-    return emitOpError() << "expected rank of input is 4";
-
-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 6)
-    return emitOpError() << "expected rank of output is 6";
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// WinogradOutputTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradOutputTransformOp::verify() {
-  auto valueType = cast<ShapedType>(getValue().getType());
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  auto valueElemType = valueType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (valueElemType != outputElemType) {
-    return emitOpError() << "expected element type of value " << valueElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
-
-  unsigned valueRank = valueType.getRank();
-  if (valueRank != 6)
-    return emitOpError() << "expected rank of input is 6";
-
-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 4)
-    return emitOpError() << "expected rank of output is 4";
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a7dcc29b5b9be..7e3dc56e0acdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,7 +38,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
-  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
deleted file mode 100644
index 86e834d51f2fc..0000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ /dev/null
@@ -1,321 +0,0 @@
-//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Implement Winograd Conv2D algorithm. The implementation is based on the
-// paper: Fast Algorithms for Convolutional Neural Networks
-// (https://arxiv.org/abs/1509.09308)
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/MathExtras.h"
-
-namespace mlir {
-namespace linalg {
-
-namespace {
-
-using TransformMapKeyTy = std::pair<int, int>;
-
-// We use F(m, r) to define the size of minimal filtering algorithms.
-// m is the output dimension and r is the filter dimension. We can get
-// the input dimension, alpha, from the formula, alpha = m + r - 1.
-//
-// For example, when m = 2 and r = 3, we know its input size is 4.
-// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
-Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
-  auto type = cast<ShapedType>(data.getType());
-  auto elementType = type.getElementType();
-  auto shape = type.getShape();
-  auto collapseType = RankedTensorType::get(
-      {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
-      elementType);
-  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
-  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
-                                                  reassociation);
-}
-
-// This function generates linalg.batch_matmul to multiply input with filter.
-// linalg.batch_matmul only supports 3-dimension data sets. We can treat
-// tileH x tileW x H x W data as the 1-dimension data array. That is to convert
-// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we
-// can convert 6-dimension input data to 3-dimension representation that is
-// suitable for linalg.batch_matmul.
-//
-// Batched matmul will do the matrix multiply with the reduction on channel.
-//
-// We get
-//
-// %collapsed_input = tensor.collapse_shape %input
-// %collapsed_filter = tensor.collapse_shape %filter
-// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
-// %expanded_ret = tensor.expand_shape %ret
-//
-// After this function, we get return value with data layout
-// (tileH, tileW, H, W, N, F).
-Value matrixMultiply(RewriterBase &rewriter, Location loc,
-                     Value transformedFilter, Value transformedInput) {
-  auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
-  auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
-
-  // Batched matrix multiply
-  auto filterType = cast<ShapedType>(transformedFilter.getType());
-  auto filterShape = filterType.getShape();
-  auto inputType = cast<ShapedType>(transformedInput.getType());
-  auto inputElemType = inputType.getElementType();
-  auto inputShape = inputType.getShape();
-
-  auto matmulType = RankedTensorType::get(
-      {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
-       inputShape[4], filterShape[5]},
-      inputElemType);
-  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                inputElemType);
-
-  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
-      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
-      ValueRange{init});
-
-  // Expand matmul result
-  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
-  auto expandType =
-      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
-                             inputShape[3], inputShape[4], filterShape[5]},
-                            inputElemType);
-  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
-      loc, expandType, matmulOp.getResult(0), reassociation);
-  return expandOutput;
-}
-
-Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
-                            RankedTensorType alignedType) {
-  Value alignedInput = rewriter.create<tensor::EmptyOp>(
-      loc, alignedType.getShape(), alignedType.getElementType());
-
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
-  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
-
-  auto valueType = cast<ShapedType>(value.getType());
-  auto valueShape = valueType.getShape();
-  SmallVector<OpFoldResult, 4> sizes;
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[0]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[1]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[2]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[3]));
-
-  return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
-                                                offsets, sizes, strides);
-}
-
-Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
-                               Value value, RankedTensorType extractedType) {
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
-  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
-
-  auto extractedShape = extractedType.getShape();
-  SmallVector<OpFoldResult, 4> sizes;
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3]));
-
-  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
-                                                 offsets, sizes, strides);
-}
-
-bool hasAllOneValues(DenseIntElementsAttr attr) {
-  return llvm::all_of(
-      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
-}
-
-FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
-                                            linalg::Conv2DNhwcFhwcOp convOp,
-                                            int64_t m, int64_t r) {
-  Value input = convOp.getInputs()[0];
-  Value filter = convOp.getInputs()[1];
-  Value output = convOp.getOutputs()[0];
-  auto inputType = cast<ShapedType>(input.getType());
-  auto filterType = cast<ShapedType>(filter.getType());
-  auto outputType = cast<ShapedType>(output.getType());
-
-  if (!inputType.hasStaticShape())
-    return rewriter.notifyMatchFailure(convOp,
-                                       "expected a static shape for the input");
-
-  if (!filterType.hasStaticShape())
-    return rewriter.notifyMatchFailure(
-        convOp, "expected a static shape for the filter");
-
-  if (!hasAllOneValues(convOp.getDilations()))
-    return rewriter.notifyMatchFailure(convOp,
-                                       "expected all ones for dilations");
-
-  if (!hasAllOneValues(convOp.getStrides()))
-    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
-
-  auto filterShape = filterType.getShape();
-  int64_t filterF = filterShape[0];
-  int64_t filterH = filterShape[1];
-  int64_t filterW = filterShape[2];
-  int64_t filterC = filterShape[3];
-  auto inputShape = inputType.getShape();
-  int64_t inputN = inputShape[0];
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
-  int64_t inputC = inputShape[3];
-  auto outputShape = outputType.getShape();
-  int64_t outputN = outputShape[0];
-  int64_t outputH = outputShape[1];
-  int64_t outputW = outputShape[2];
-  int64_t outputF = outputShape[3];
-
-  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
-  bool isSupportedFilter = false;
-  if (filterH == filterW && filterH == r)
-    isSupportedFilter = true;
-  if (filterH == r && filterW == 1)
-    isSupportedFilter = true;
-  if (filterH == 1 && filterW == r)
-    isSupportedFilter = true;
-
-  if (!isSupportedFilter)
-    return rewriter.notifyMatchFailure(
-        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
-
-  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5)
-  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
-      F_2_3, F_4_3, F_2_5};
-
-  TransformMapKeyTy key = {m, r};
-  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
-  // If we cannot find the constant transformation matrix, it means we do
-  // not support this configuration yet.
-  if (it == validConfigs.end())
-    return failure();
-
-  // All the criterias are satisfied. We can do Winograd Conv2D.
-  Location loc = convOp.getLoc();
-
-  // For F(m x 1, r x 1), we only need to do left side transform.
-  bool leftTransform = filterH != 1;
-  // For F(1 x m, 1 x r), we only need to do right side transform.
-  bool rightTransform = filterW != 1;
-  int64_t heightM = leftTransform ? m : 1;
-  int64_t widthM = rightTransform ? m : 1;
-  int64_t heightR = leftTransform ? r : 1;
-  int64_t widthR = rightTransform ? r : 1;
-
-  // --- Create operator for filter transform ---
-  Type elementType = filterType.getElementType();
-  int64_t alphaH = heightM + heightR - 1;
-  int64_t alphaW = widthM + widthR - 1;
-  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
-  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
-  auto retType = RankedTensorType::get(
-      {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
-  Value retValue =
-      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
-  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
-      loc, retType, filter, retValue, m, r);
-
-  // --- Create operator for input transform ---
-
-  // When input size - (r - 1) is not aligned with output tile size, we need to
-  // pad the input data to create the full tiles as tiling.
-  int64_t alignedInputH = tileH * heightM + (heightR - 1);
-  int64_t alignedInputW = tileW * widthM + (widthR - 1);
-  if (alignedInputH != inputH || alignedInputW != inputW) {
-    auto alignedInputType = RankedTensorType::get(
-        {inputN, alignedInputH, alignedInputW, inputC}, elementType);
-    input = insertToAlignedTensor(rewriter, loc, input, alignedInputType);
-  }
-
-  retType = RankedTensorType::get(
-      {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType);
-  retValue =
-      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
-  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
-      loc, retType, input, retValue, m, r);
-
-  Value matmulRet =
-      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
-
-  // --- Create operator for output transform ---
-
-  // When output size is not aligned with output tile size, we need to pad the
-  // output buffer to insert the full tiles after tiling.
-  int64_t alignedOutputH = tileH * heightM;
-  int64_t alignedOutputW = tileW * widthM;
-  bool isOutputUnaligned =
-      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
-  if (isOutputUnaligned) {
-    auto alignedOutputType = RankedTensorType::get(
-        {outputN, alignedOutputH, alignedOutputW, outputF}, elementType);
-    output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType);
-    outputType = alignedOutputType;
-  }
-
-  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
-      loc, outputType, matmulRet, output, m, r);
-
-  // When output size is not aligned with output tile size, extract the
-  // value from the padded buffer.
-  if (isOutputUnaligned) {
-    transformedOutput = extractFromAlignedTensor(
-        rewriter, loc, transformedOutput,
-        RankedTensorType::get({outputN, outputH, outputW, outputF},
-                              elementType));
-  }
-
-  rewriter.replaceOp(convOp, transformedOutput);
-
-  return transformedOutput.getDefiningOp();
-}
-
-class WinogradConv2DNhwcFhwc final
-    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
-      : OpRewritePattern(context), m(m), r(r) {}
-
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
-      return failure();
-
-    return success();
-  }
-
-private:
-  int64_t m;
-  int64_t r;
-};
-} // end anonymous namespace
-
-//===----------------------------------------------------------------------===//
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r) {
-  MLIRContext *context = patterns.getContext();
-  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
-}
-
-} // end namespace linalg
-} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
deleted file mode 100644
index 6cca3c602d4c0..0000000000000
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ /dev/null
@@ -1,248 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
-
-func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_4x4_3x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
-  %0 = tensor.empty() : tensor<2x2x2x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x2x2x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-  return %2 : tensor<2x2x2x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_2x2_5x5
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x1x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x1x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-  return %2 : tensor<2x1x4x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_1x4_1x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x1x4x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x1x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x1x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-  return %2 : tensor<2x4x1x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_4x1_3x1
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x4x1x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-  %0 = tensor.empty() : tensor<2x8x8x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x8x8x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-  return %2 : tensor<2x8x8x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_aligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-  %0 = tensor.empty() : tensor<2x9x9x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x9x9x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
-  return %2 : tensor<2x9x9x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:  %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT:  %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT:  %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_1
-// CHECK: linalg.conv_2d_nhwc_fhwc
-
-// -----
-
-func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_2
-// CHECK: linalg.conv_2d_nhwc_fhwc
-
-// -----
-
-func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
-  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-  return %0 : tensor<?x?x?x?xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_3
-// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 12cb46a5968f1..4892fa2f99a7c 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,10 +123,6 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
-  Option<bool> testWinogradConv2D{
-      *this, "test-winograd-conv2d",
-      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
-      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -211,13 +207,6 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyWinogradConv2D(func::FuncOp funcOp) {
-  RewritePatternSet patterns(funcOp.getContext());
-  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
-  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -242,8 +231,6 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
-  if (testWinogradConv2D)
-    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {

>From 21afe380fd73ac671f9df76a0b0780dbb328fff1 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:24:07 +0100
Subject: [PATCH 12/22] [mlir][linalg] Implement Conv2D using Winograd Conv2D
 algorithm

Define high level winograd operators and convert conv_2d_nhwc_fhwc into
winograd operators. According to Winograd Conv2D algorithm, we need
three transform operators for input, filter, and output transformation.

The formula of Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper, Fast Algorithms for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       | 117 ++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 107 ++++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 334 ++++++++++++++++++
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 193 ++++++++++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  13 +
 7 files changed, 769 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..a9007c8db3078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,121 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp :
+    Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp :
+    Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
+                       TensorRankOf<[AnyType], [6]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp :
+    Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..1283315f2eaef 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,113 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t r = getR();
+  // Each of filter H/W must be r or 1, and they must not both be 1.
+  if (filterH != r && filterH != 1)
+    return emitOpError("expect filter height to be either r or 1");
+  if (filterW != r && filterW != 1)
+    return emitOpError("expect filter width to be either r or 1");
+  if (filterH == 1 && filterW == 1)
+    return emitOpError("expect filter height and width not to both be 1");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[0];
+  int64_t outputW = outputShape[1];
+  int64_t outputTileH = outputShape[2];
+  int64_t outputTileW = outputShape[3];
+  int64_t m = getM();
+  int64_t r = getR();
+  bool leftTransform = inputH != 1;
+  bool rightTransform = inputW != 1;
+  if (m <= 0 || r <= 0) // Guards the division by m below.
+    return emitOpError("expect m and r to be positive");
+  if (!leftTransform && !rightTransform)
+    return emitOpError("expect input height or width to differ from 1");
+  if (leftTransform) {
+    int64_t tileH = (inputH - (r - 1)) / m;
+    if (inputH != tileH * m + (r - 1))
+      return emitOpError("expect input height - (r - 1) to be divisible by m");
+    if (tileH != outputTileH)
+      return emitOpError("expect output tile count in height to match input");
+    if (outputH != m + r - 1)
+      return emitOpError("expect output height to be m + r - 1");
+  }
+
+  if (rightTransform) {
+    int64_t tileW = (inputW - (r - 1)) / m;
+    if (inputW != tileW * m + (r - 1))
+      return emitOpError("expect input width - (r - 1) to be divisible by m");
+    if (tileW != outputTileW)
+      return emitOpError("expect output tile count in width to match input");
+    if (outputW != m + r - 1)
+      return emitOpError("expect output width to be m + r - 1");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueTileH = valueShape[2];
+  int64_t valueTileW = valueShape[3];
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  bool leftTransform = valueH != 1;
+  bool rightTransform = valueW != 1;
+  if (m <= 0 || r <= 0)
+    return emitOpError("expect m and r to be positive");
+  if (!leftTransform && !rightTransform)
+    return emitOpError("expect value height or width to differ from 1");
+  if (leftTransform) {
+    if (valueH != m + r - 1)
+      return emitOpError("expect value height to be m + r - 1");
+    if (outputH != m * valueTileH)
+      return emitOpError("expect output height to be m * number of tiles");
+  }
+
+  if (rightTransform) {
+    if (valueW != m + r - 1)
+      return emitOpError("expect value width to be m + r - 1");
+    if (outputW != m * valueTileW)
+      return emitOpError("expect output width to be m * number of tiles");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..6b46f9e07abf8
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,334 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+/// Key naming a minimal filtering algorithm F(m, r) by the pair (m, r).
+typedef std::pair<int, int> TransformMapKeyTy;
+
+/// F(m, r) denotes a minimal filtering algorithm that produces m outputs
+/// from a filter of size r; the input size it consumes, alpha, is m + r - 1.
+///
+/// Example: with m = 2 and r = 3 the input size is 4, so the 2-D Conv2D
+/// reads a 4x4 input tile, applies a 3x3 filter and yields a 2x2 output tile.
+constexpr TransformMapKeyTy F_2_3 = std::make_pair(2, 3);
+constexpr TransformMapKeyTy F_4_3 = std::make_pair(4, 3);
+constexpr TransformMapKeyTy F_2_5 = std::make_pair(2, 5);
+
+/// This function generates linalg.batch_matmul to multiply the transformed
+/// input with the transformed filter. linalg.batch_matmul only supports
+/// 3-dimensional operands, so the alphaH x alphaW dimensions are collapsed
+/// into the batch dimension and tileH x tileW x N into the row dimension:
+///
+///   filter: (alphaH, alphaW, C, F) -> (alphaH x alphaW, C, F)
+///   input:  (alphaH, alphaW, tileH, tileW, N, C)
+///             -> (alphaH x alphaW, tileH x tileW x N, C)
+///
+/// Each batch then performs a matrix multiply with the reduction on channel C.
+///
+/// We get
+///
+/// %collapsed_input = tensor.collapse_shape %input
+/// %collapsed_filter = tensor.collapse_shape %filter
+/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+/// %expanded_ret = tensor.expand_shape %ret
+///
+/// After this function, we get return value with data layout
+/// (alphaH, alphaW, tileH, tileW, N, F) and element type `outputElementType`.
+static Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                            Value transformedFilter, Value transformedInput,
+                            Type outputElementType) {
+  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  assert(filterType.hasStaticShape() && "only support static shapes.");
+  assert(filterType.getRank() == 4 && "expected a rank-4 filter");
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  Type filterElementType = filterType.getElementType();
+  auto filterReassocType = RankedTensorType::get(
+      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
+      filterElementType);
+  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
+  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, filterReassocType, transformedFilter, filterReassoc);
+
+  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
+  // (alphaH x alphaW, tileH x tileW x N, C) for input.
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  assert(inputType.hasStaticShape() && "only support static shapes.");
+  assert(inputType.getRank() == 6 && "expected a rank-6 input");
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  Type inputElementType = inputType.getElementType();
+  auto inputReassocType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
+      inputElementType);
+  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, inputReassocType, transformedInput, inputReassoc);
+
+  // Batched matrix multiply.
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
+      outputElementType);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                outputElementType);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
+  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
+  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  auto outputReassocType =
+      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+                             inputShape[3], inputShape[4], filterShape[3]},
+                            outputElementType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
+  return expandOutput;
+}
+
+/// Zero-pad `value` at the high end of each dimension with a tensor.pad so
+/// that the result has shape `alignedShape`. `alignedShape` must have the same
+/// rank as `value` and must not be smaller than it in any dimension.
+static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
+                                   Value value,
+                                   ArrayRef<int64_t> alignedShape) {
+  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  assert(valueShape.size() == alignedShape.size() &&
+         "expected the aligned shape to have the same rank as the value");
+
+  // Nothing is padded at the low end; the high end receives the difference
+  // between the aligned size and the current size.
+  SmallVector<OpFoldResult, 6> lowIndices(alignedShape.size(), zeroIndex);
+  SmallVector<OpFoldResult, 6> highIndices;
+  highIndices.reserve(alignedShape.size());
+  for (size_t i = 0, e = alignedShape.size(); i < e; ++i) {
+    assert(alignedShape[i] >= valueShape[i] &&
+           "cannot pad a dimension to a smaller size");
+    highIndices.push_back(
+        rewriter.getIndexAttr(alignedShape[i] - valueShape[i]));
+  }
+  auto alignedType = RankedTensorType::get(alignedShape, elementType);
+  Value padValue = rewriter.create<arith::ConstantOp>(
+      loc, elementType, rewriter.getZeroAttr(elementType));
+  return rewriter.create<tensor::PadOp>(loc, alignedType, value, lowIndices,
+                                        highIndices, padValue);
+}
+
+/// Extract the leading sub-tensor of type `extractedType` from `value`: all
+/// offsets are zero and all strides are one. `value` must have the same rank
+/// as `extractedType`; any rank is accepted.
+static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+                                      Value value,
+                                      RankedTensorType extractedType) {
+  int64_t rank = extractedType.getRank();
+  assert(cast<ShapedType>(value.getType()).getRank() == rank &&
+         "expected the source and extracted types to have the same rank");
+
+  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
+  // Size offsets/strides from the extracted type instead of assuming rank 4.
+  SmallVector<OpFoldResult, 4> offsets(rank, zeroIndex);
+  SmallVector<OpFoldResult, 4> strides(rank, oneIndex);
+
+  ArrayRef<int64_t> extractedShape = extractedType.getShape();
+  SmallVector<OpFoldResult> sizes =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));
+
+  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+                                                 offsets, sizes, strides);
+}
+
+/// Returns true when every element of `attr`, read as a signed integer, is
+/// exactly 1 (an attribute with no elements trivially qualifies).
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  for (const APInt &value : attr) {
+    if (value.getSExtValue() != 1)
+      return false;
+  }
+  return true;
+}
+
+/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
+/// linalg.winograd_*_transform ops.
+///
+/// Preconditions, each reported through notifyMatchFailure before any IR is
+/// created:
+/// - static input, filter and output shapes;
+/// - unit dilations and strides;
+/// - a filter of size (r x r), (r x 1) or (1 x r);
+/// - a supported (m, r) pair.
+///
+/// On success `convOp` is replaced, and the op defining the replacement value
+/// is returned.
+static FailureOr<Operation *>
+winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
+                     int64_t m, int64_t r) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  // The output sizes feed the tile-count and padding computations below. A
+  // dynamic size (the negative ShapedType::kDynamic sentinel) would produce
+  // invalid negative tensor dimensions.
+  if (!outputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the output");
+
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  if (!hasAllOneValues(convOp.getStrides()))
+    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputN = outputShape[0];
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t outputF = outputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
+  bool isSupportedFilter = (filterH == r && filterW == r) ||
+                           (filterH == r && filterW == 1) ||
+                           (filterH == 1 && filterW == r);
+  if (!isSupportedFilter)
+    return rewriter.notifyMatchFailure(
+        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
+  static constexpr TransformMapKeyTy validConfigs[] = {F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (!llvm::is_contained(validConfigs, key))
+    return rewriter.notifyMatchFailure(
+        convOp, "unsupported Winograd (m, r) configuration");
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  int64_t heightM = leftTransform ? m : 1;
+  int64_t widthM = rightTransform ? m : 1;
+  int64_t heightR = leftTransform ? r : 1;
+  int64_t widthR = rightTransform ? r : 1;
+
+  // --- Create operation for filter transform ---
+  Type filterElementType = filterType.getElementType();
+  int64_t alphaH = heightM + heightR - 1;
+  int64_t alphaW = widthM + widthR - 1;
+  // Number of output tiles along each spatial dimension, rounded up.
+  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
+                                       filterElementType);
+  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                                    filterElementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, m, r);
+
+  // --- Create operation for input transform ---
+
+  // When input size - (r - 1) is not aligned with output tile size, we need to
+  // pad the input data to create the full tiles as tiling.
+  Type inputElementType = inputType.getElementType();
+  int64_t alignedInputH = tileH * heightM + (heightR - 1);
+  int64_t alignedInputW = tileW * widthM + (widthR - 1);
+  if (alignedInputH != inputH || alignedInputW != inputW) {
+    input = insertToAlignedTensor(
+        rewriter, loc, input, {inputN, alignedInputH, alignedInputW, inputC});
+  }
+
+  retType = RankedTensorType::get(
+      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
+  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                              inputElementType);
+  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+      loc, retType, input, retValue, m, r);
+
+  Type outputElementType = outputType.getElementType();
+  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
+                                   transformedInput, outputElementType);
+
+  // --- Create operation for output transform ---
+
+  // When output size is not aligned with output tile size, we need to pad the
+  // output buffer to insert the full tiles after tiling.
+  int64_t alignedOutputH = tileH * heightM;
+  int64_t alignedOutputW = tileW * widthM;
+  bool isOutputUnaligned =
+      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
+  if (isOutputUnaligned) {
+    auto alignedOutputType = RankedTensorType::get(
+        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
+    output = insertToAlignedTensor(rewriter, loc, output,
+                                   alignedOutputType.getShape());
+    outputType = alignedOutputType;
+  }
+
+  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+      loc, outputType, matmulRet, output, m, r);
+
+  // When output size is not aligned with output tile size, extract the
+  // value from the padded buffer.
+  if (isOutputUnaligned) {
+    transformedOutput = extractFromAlignedTensor(
+        rewriter, loc, transformedOutput,
+        RankedTensorType::get({outputN, outputH, outputW, outputF},
+                              outputElementType));
+  }
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+/// Rewrite pattern lowering linalg.conv_2d_nhwc_fhwc through the Winograd
+/// transform ops for one F(m, r) configuration fixed at construction time.
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    // The helper performs the replacement itself; only its status matters.
+    FailureOr<Operation *> replacement =
+        winogradConv2DHelper(rewriter, convOp, m, r);
+    return failed(replacement) ? failure() : success();
+  }
+
+private:
+  /// Output tile size of the F(m, r) algorithm.
+  int64_t m;
+  /// Filter size of the F(m, r) algorithm.
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  // Register a single Winograd Conv2D rewrite for the requested F(m, r).
+  patterns.add<WinogradConv2DNhwcFhwc>(patterns.getContext(), m, r);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..ec11a6ef8fbee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,193 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+
+// F(4x4, 3x3): a 6x6 input with a 3x3 filter yields a 4x4 output, which is
+// exactly one 6x6 input tile (tileH = tileW = 1), so no padding is emitted.
+// %arg2 is not used by the convolution.
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// F(2x2, 5x5): a 6x6 input with a 5x5 filter yields a 2x2 output, again a
+// single 6x6 input tile; lowered with m(2) r(5).
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%out : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %0 : tensor<2x2x2x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_2x2_5x5
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// F(1x4, 1x3): with a 1x3 filter only the width-side transform applies, so
+// the transformed filter and input keep a size-1 leading (alphaH) dimension.
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%out : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+  return %0 : tensor<2x1x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_1x4_1x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// F(4x1, 3x1): with a 3x1 filter only the height-side transform applies, so
+// the transformed filter and input keep a size-1 second (alphaW) dimension.
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%out : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %0 : tensor<2x4x1x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_4x1_3x1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// Aligned case: the 8x8 output is a multiple of m = 4, giving 2x2 tiles whose
+// inputs exactly cover the 10x10 input, so neither input nor output is padded.
+func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_aligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// Unaligned case: the 9x9 output is not a multiple of m = 4, so 3x3 tiles are
+// used. The input is zero-padded from 11x11 to 14x14 and the output buffer
+// from 9x9 to 12x12; the 9x9 result is then extracted from the padded buffer.
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %0 : tensor<2x9x9x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:  ^bb0
+// CHECK-NEXT:    tensor.yield %[[CST]] : f32
+// CHECK-NEXT:  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT:  %[[PADDED_1:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:  ^bb0
+// CHECK-NEXT:    tensor.yield %[[CST]] : f32
+// CHECK-NEXT:  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// Mixed precision: f16 input and filter with an f32 output. The filter and
+// input transforms stay in f16 while the batch_matmul accumulates into f32.
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf16>, tensor<2x3x3x5xf16>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_type_promotion
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:   return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// Negative test: a 3x2 filter is neither (r x r), (r x 1) nor (1 x r), so the
+// convolution must be left untouched.
+func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_1
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+// Negative test: a 4x4 filter (r = 4) matches neither the r = 3 nor the r = 5
+// rewrite registered by the test pass, so the convolution is left untouched.
+func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_2
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+// Negative test: the input has a dynamic shape and the rewrite requires
+// static shapes, so the convolution is left untouched.
+func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_3
+// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..12cb46a5968f1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  /// When set, runOnOperation applies the Winograd Conv2D rewrite patterns
+  /// (via applyWinogradConv2D) to the current function.
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+/// Greedily applies the Winograd Conv2D rewrite patterns for F(4x4, 3x3) and
+/// F(2x2, 5x5) to `funcOp`.
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
+  RewritePatternSet winogradPatterns(ctx);
+  // Register one set of patterns per (m, r) configuration under test.
+  populateWinogradConv2DPatterns(winogradPatterns, /*m=*/4, /*r=*/3);
+  populateWinogradConv2DPatterns(winogradPatterns, /*m=*/2, /*r=*/5);
+  // The driver's convergence result is deliberately discarded.
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(winogradPatterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {

>From 41c86eabf4d2022952e7957d775f82652bcc2f5d Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:49:08 +0100
Subject: [PATCH 13/22] [mlir][linalg] Add transform operator for Winograd
 Conv2D algorithm

Add a transform operator structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operators.
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 ++
 .../TransformOps/LinalgTransformOps.cpp       | 29 +++++++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  6 ++
 .../Linalg/transform-winograd-conv2d.mlir     | 76 +++++++++++++++++++
 5 files changed, 169 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..5ef56bc97fef1 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operation into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    #### Return modes:
+
+    This operation produces a silenceable failure if `target` is unsupported.
+    Otherwise, the operation succeeds and returns a handle of the sequence that
+    replaces the original convolution.
+  }];
+
+  // $target: handle to the payload ops to rewrite (processed one op at a time
+  // through applyToOne). $m and $r select the F(m x m, r x r) configuration:
+  // $m is the output tile size and $r is the filter size.
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,
+                       I64Attr:$r);
+  // $transformed: handle to the op produced by the rewrite.
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  // NOTE(review): this custom builder takes only `target`, yet `m` and `r`
+  // are required I64Attr arguments, and this patch adds no C++ definition for
+  // it. Confirm it is defined elsewhere and sets m/r, or remove it.
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..da107b66257a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of output and r is the dimension
+/// size of filter.
+///
+/// On success, returns the op that replaces the convolution. Returns failure
+/// when the rewrite cannot be applied (presumably for an unsupported (m, r)
+/// pair or shape, e.g. dynamic spatial sizes; confirm in the helper).
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..e0f2d00400d63 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,35 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+/// Rewrites a single payload op with Winograd Conv2D F(m x m, r x r).
+///
+/// Only linalg.conv_2d_nhwc_fhwc is supported. Any other op yields a
+/// silenceable failure reported at the payload op's location. A supported
+/// op that the rewrite rejects yields a silenceable failure on this transform
+/// op. On success, the replacement op is appended to `results`.
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Bail out before touching `results`. The previous code fell through here
+  // and dereferenced a failed FailureOr for unsupported ops.
+  auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(target.getOperation());
+  if (!convOp)
+    return emitSilenceableFailure(target->getLoc()) << "not supported";
+
+  rewriter.setInsertionPoint(convOp);
+  FailureOr<Operation *> maybeTransformed =
+      winogradConv2D(rewriter, convOp, getM(), getR());
+  if (failed(maybeTransformed))
+    return emitSilenceableError() << "apply Winograd Conv2D failed";
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 6b46f9e07abf8..843db0c069813 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -324,6 +324,12 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
+/// Public entry point for the Winograd Conv2D rewrite of `op` with
+/// F(m x m, r x r); forwards to winogradConv2DHelper and returns its result
+/// unchanged.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  FailureOr<Operation *> rewritten = winogradConv2DHelper(rewriter, op, m, r);
+  return rewritten;
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..0a2dcc035ebd3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s
+
+// Aligned case: 10x10 input, 3x3 filter, 8x8 output (a multiple of m = 4).
+// Expect the filter and input transforms, a batch_matmul, and the output
+// transform, all with m(4) r(3).
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @conv2d
+// CHECK: linalg.winograd_filter_transform m(4) r(3)
+// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: linalg.batch_matmul
+// CHECK: linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+// Unaligned case: the 9x9 output is not a multiple of m = 4, so the input and
+// the output are padded high by 3 in H and W before the transforms.
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %0 : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK:       linalg.winograd_filter_transform m(4) r(3)
+// CHECK:       tensor.pad
+// CHECK-SAME:  low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:       linalg.winograd_input_transform m(4) r(3)
+// CHECK:       tensor.pad
+// CHECK-SAME:  low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:       linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+// linalg.conv_2d_nhwc_hwcf is not handled by the transform op; a diagnostic
+// is reported on the payload convolution.
+func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  // expected-error @+1 {{not supported}}
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Dynamic H and W: the op is a supported convolution, but the rewrite itself
+// fails and the transform op reports the failure.
+func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
+  return %0 : tensor<2x?x?x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{apply Winograd Conv2D failed}}
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}

>From 6c4f432caf655e4c2e9cd27755231473165d3797 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 17:39:49 +0100
Subject: [PATCH 14/22] [mlir][linalg] Decompose winograd operators

Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops with matrix multiplication
with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel sizes (3 and 5) and different output sizes
(2 and 4). Besides symmetric kernel sizes 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
---
 .../Dialect/Linalg/Transforms/Transforms.h    |   3 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 785 ++++++++++++++++++
 .../Linalg/winograd-conv2d-rewrite.mlir       | 105 +++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  11 +
 4 files changed, 904 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index da107b66257a5..bb7ec590faad0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1703,6 +1703,9 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r);
 
+/// Populate `patterns` with rewrites that decompose the Linalg Winograd
+/// operators (winograd_filter_transform, winograd_input_transform and
+/// winograd_output_transform) into nested loops of linalg.matmul with
+/// constant transform matrices.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 843db0c069813..d2dfe366e55d3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -12,7 +12,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -24,6 +27,156 @@ namespace linalg {
 
 namespace {
 
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. For the minimal 2D filtering algorithm F(m x m, r x r), where m is
+// the output tile size and r is the filter size, the result is
+//
+// Y = A^T x [ (G x g x G^T) @ (B^T x d x B) ] x A
+//
+// where x is matrix multiplication and @ is element-wise multiplication.
+// g is the filter and d is the input data. We need to prepare 6 constant
+// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
+//
+// The following tables define these constant transformation matrices for
+// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5). Each table is
+// stored row-major, and every XT table is the transpose of the matching X
+// table.
+
+// F(2 x 2, 3 x 3): G is 4x3, B^T and B are 4x4, A^T is 2x4.
+constexpr float G_2x2_3x3[] = {
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = {
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = {
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = {
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = {
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+// F(4 x 4, 3 x 3): G is 6x3, B^T and B are 6x6, A^T is 4x6. outputTransform
+// registers the A^T/A tables with a scalar factor of 32.
+constexpr float G_4x4_3x3[] = {
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+// F(2 x 2, 5 x 5): G is 6x5, B^T and B are 6x6, A^T is 2x6. outputTransform
+// registers the A^T/A tables with a scalar factor of 16.
+constexpr float G_2x2_5x5[] = {
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = {
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
 using TransformMapKeyTy = std::pair<int, int>;
 
 /// We use F(m, r) to define the size of minimal filtering algorithms.
@@ -37,6 +190,359 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+/// Non-owning view of a constant, row-major transform matrix.
+struct TransformMatrix {
+  /// `data` must point at `numRows * numCols` floats laid out row by row and
+  /// must outlive this object. `factor` is stored as `scalarFactor`.
+  TransformMatrix(const float *data, int64_t numRows, int64_t numCols,
+                  int64_t factor = 1)
+      : table(data), rows(numRows), cols(numCols), scalarFactor(factor) {}
+
+  /// Row-major entries (not owned).
+  const float *table;
+  /// Number of rows.
+  int64_t rows;
+  /// Number of columns.
+  int64_t cols;
+  /// Extra scale associated with the matrix (1 unless given); how it is
+  /// applied is up to the consumer (not visible here).
+  int64_t scalarFactor;
+};
+
+/// Materializes `transform` as an arith.constant of type
+/// tensor<rows x cols x `type`>.
+///
+/// The tables are stored as float. Each entry is converted (round to nearest,
+/// ties to even) to the semantics of `type`. This makes any floating-point
+/// element type work, not only f32: building the attribute directly from
+/// ArrayRef<float> asserts for e.g. f16/bf16/f64 element types. For f32 the
+/// conversion is a no-op and the resulting constant is unchanged.
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+                              TransformMatrix transform, Type type) {
+  auto floatType = cast<FloatType>(type);
+  ArrayRef<float> entries(transform.table, transform.rows * transform.cols);
+
+  SmallVector<APFloat> values;
+  values.reserve(entries.size());
+  for (float entry : entries) {
+    APFloat value(entry);
+    bool losesInfo = false;
+    // Entries such as 1/3 are inexact in narrow types; rounding is intended.
+    (void)value.convert(floatType.getFloatSemantics(),
+                        APFloat::rmNearestTiesToEven, &losesInfo);
+    values.push_back(value);
+  }
+
+  auto constType = RankedTensorType::get(
+      SmallVector<int64_t>{transform.rows, transform.cols}, type);
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(constType, values));
+}
+
+/// Extracts one height x width plane from the rank-`srcSize` tensor `source`
+/// and returns it as a rank-2 tensor.
+///
+/// Dimension `outLoopIdx` is offset by `outLoopIndex` and dimension
+/// `inLoopIdx` by `inLoopIndex`. Dimensions `heightIdx` / `widthIdx` are taken
+/// in full, and every other dimension is sliced at offset 0 with size 1.
+Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto srcType = cast<ShapedType>(source.getType());
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  const int64_t h = srcShape[heightIdx];
+  const int64_t w = srcShape[widthIdx];
+
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+  OpFoldResult one = rewriter.getIndexAttr(1);
+
+  // Start from a unit slice at the origin, then position the two loop
+  // dimensions and widen the 2-D plane to its full size.
+  SmallVector<OpFoldResult, 6> offsets(srcSize, zero);
+  SmallVector<OpFoldResult, 6> sizes(srcSize, one);
+  SmallVector<OpFoldResult, 6> strides(srcSize, one);
+  SmallVector<int64_t> sliceShape(srcSize, 1);
+  offsets[outLoopIdx] = outLoopIndex;
+  offsets[inLoopIdx] = inLoopIndex;
+  sizes[heightIdx] = rewriter.getIndexAttr(h);
+  sizes[widthIdx] = rewriter.getIndexAttr(w);
+  sliceShape[heightIdx] = h;
+  sliceShape[widthIdx] = w;
+
+  Type elemType = srcType.getElementType();
+  Value slice = rewriter.create<tensor::ExtractSliceOp>(
+      loc, RankedTensorType::get(sliceShape, elemType), source, offsets, sizes,
+      strides);
+  // Drop the unit dimensions so the caller gets a plain h x w tensor.
+  return tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, slice, RankedTensorType::get({h, w}, elemType));
+}
+
+/// Writes the rank-2 tensor `source` (height x width) into the rank-`destSize`
+/// tensor `dest` and returns the updated tensor.
+///
+/// `source` is first lifted to rank `destSize` (unit dimensions everywhere
+/// except `heightIdx` / `widthIdx`). It is then inserted with dimension
+/// `outLoopIdx` offset by `outLoopIndex`, `inLoopIdx` by `inLoopIndex`, and
+/// all other offsets 0.
+Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
+                   Value dest, Value outLoopIndex, Value inLoopIndex,
+                   int64_t height, int64_t width, int64_t outLoopIdx,
+                   int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                   int64_t destSize) {
+  Type elemType = cast<ShapedType>(source.getType()).getElementType();
+
+  // Lift the 2-D tile back to full rank inside an empty tensor.
+  SmallVector<int64_t> expandedShape(destSize, 1);
+  expandedShape[heightIdx] = height;
+  expandedShape[widthIdx] = width;
+  Value empty = rewriter.create<tensor::EmptyOp>(loc, expandedShape, elemType);
+  Value expanded = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, source, empty);
+
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+  OpFoldResult one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> offsets(destSize, zero);
+  SmallVector<OpFoldResult, 6> sizes(destSize, one);
+  SmallVector<OpFoldResult, 6> strides(destSize, one);
+  offsets[outLoopIdx] = outLoopIndex;
+  offsets[inLoopIdx] = inLoopIndex;
+  sizes[heightIdx] = rewriter.getIndexAttr(height);
+  sizes[widthIdx] = rewriter.getIndexAttr(width);
+
+  // Place the lifted tile into the destination.
+  return rewriter.create<tensor::InsertSliceOp>(loc, expanded, dest, offsets,
+                                                sizes, strides);
+}
+
+/// Collapses a rank-6 tensor to rank 3 by merging its four leading dimensions
+/// into one; dimensions 4 and 5 are kept unchanged.
+Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+  auto dataType = cast<ShapedType>(data.getType());
+  ArrayRef<int64_t> dims = dataType.getShape();
+  int64_t leading = dims[0] * dims[1] * dims[2] * dims[3];
+  SmallVector<ReassociationIndices> groups = {{0, 1, 2, 3}, {4}, {5}};
+  auto resultType = RankedTensorType::get({leading, dims[4], dims[5]},
+                                          dataType.getElementType());
+  return rewriter.create<tensor::CollapseShapeOp>(loc, resultType, data,
+                                                  groups);
+}
+
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is 2-dimension. We need to extract H x W from
+// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+// After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract filter<h x w> from filter<f x h x w x c>
+//     %ret = linalg.matmul G, %extracted
+//     %ret = linalg.matmul %ret, GT
+//     %inserted = insert %ret into filter<tile_h x tile_w x h x w x c x f>
+//
+// `retValue` is the destination tensor threaded through both loops as the
+// iteration argument. `leftTransform` / `rightTransform` select whether the
+// H / W dimension is transformed (G x g / g x G^T). A transformed dimension
+// must have size r, and an untransformed one must have size 1.
+//
+// Returns a null Value, without creating any IR, if the filter is dynamically
+// shaped, has an unexpected H/W, or (m, r) is not a supported configuration.
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  // All validation happens before any op is created. Previously a missing
+  // (m, r) key returned after the loops were built, leaving loops without
+  // terminators and the IR half-rewritten.
+  auto filterType = cast<ShapedType>(filter.getType());
+  // Loop bounds and slice sizes below are built from static sizes.
+  if (!filterType.hasStaticShape())
+    return Value();
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  // G is (m + r - 1) x r, so a transformed dimension must have size r. Any
+  // other combination would produce shape-mismatched matmuls.
+  if (filterH != (leftTransform ? r : 1))
+    return Value();
+  if (filterW != (rightTransform ? r : 1))
+    return Value();
+
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *GMatrix = nullptr;
+  if (leftTransform) {
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    GMatrix = &it->second;
+  }
+  const TransformMatrix *GTMatrix = nullptr;
+  if (rightTransform) {
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    GTMatrix = &it->second;
+  }
+
+  // The result has shape (1, 1, alphaH, alphaW, C, F).
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value FIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (F, H, W, C)
+  auto extractFilter = extract2DData(
+      rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+  int64_t retRows = 1;
+  Value matmulRetValue = extractFilter;
+  if (leftTransform) {
+    retRows = GMatrix->rows;
+    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value G = create2DTransformMatrix(rewriter, loc, *GMatrix, elementType);
+    // Multiply G x g
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    auto matmulType =
+        RankedTensorType::get({retRows, GTMatrix->cols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value GT = create2DTransformMatrix(rewriter, loc, *GTMatrix, elementType);
+    // Multiply u = (G x g) x GT
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert (H, W) to (1, 1, H, W, C, F)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  int64_t retHeight = leftTransform ? m + r - 1 : 1;
+  int64_t retWidth = rightTransform ? m + r - 1 : 1;
+  auto insertSliceOp = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth,
+      /*outLoopIdx=*/5, /*inLoopIdx=*/4, /*heightIdx=*/2, /*widthIdx=*/3,
+      /*destSize=*/6);
+
+  rewriter.create<scf::YieldOp>(loc, insertSliceOp);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// This function transforms the input. The data layout of the input is NHWC.
+// The transformation matrix is 2-dimension. We need to extract H x W from
+// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract input<h x w> from input<n x h x w x c>
+//     %ret = linalg.matmul BT, %extracted
+//     %ret = linalg.matmul %ret, B
+//     %inserted = insert %ret into input<tile_h x tile_w x h x w x n x c>
+//
+// `retValue` is the destination tensor threaded through both loops as the
+// iteration argument. `leftTransform` / `rightTransform` select whether the
+// H / W dimension is transformed (B^T x d / d x B). A transformed dimension
+// must have size alpha = m + r - 1, and an untransformed one must have size 1.
+//
+// Returns a null Value, without creating any IR, if the input is dynamically
+// shaped, has an unexpected H/W, or (m, r) is not a supported configuration.
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  // All validation happens before any op is created. Previously a missing
+  // (m, r) key returned after the loops were built, leaving loops without
+  // terminators and the IR half-rewritten.
+  auto inputType = cast<ShapedType>(input.getType());
+  // Loop bounds and slice sizes below are built from static sizes.
+  if (!inputType.hasStaticShape())
+    return Value();
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  // B^T and B are alpha x alpha, so a transformed dimension must have size
+  // alpha. Any other combination would produce shape-mismatched matmuls.
+  if (inputH != alphaH)
+    return Value();
+  if (inputW != alphaW)
+    return Value();
+
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *BTMatrix = nullptr;
+  if (leftTransform) {
+    auto it = BTMatrices.find(key);
+    if (it == BTMatrices.end())
+      return Value();
+    BTMatrix = &it->second;
+  }
+  const TransformMatrix *BMatrix = nullptr;
+  if (rightTransform) {
+    auto it = BMatrices.find(key);
+    if (it == BMatrices.end())
+      return Value();
+    BMatrix = &it->second;
+  }
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (N, H, W, C)
+  auto extractInput = extract2DData(
+      rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  Value matmulRetValue = extractInput;
+  // NOTE(review): B^T and B are always materialized as f32 here, whereas
+  // filterTransform uses the filter's element type; consider unifying.
+  if (leftTransform) {
+    retRows = BTMatrix->rows;
+    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value BT = create2DTransformMatrix(rewriter, loc, *BTMatrix,
+                                       rewriter.getF32Type());
+    // Multiply BT x d
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    retCols = BMatrix->cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+    Value B = create2DTransformMatrix(rewriter, loc, *BMatrix,
+                                      rewriter.getF32Type());
+    // Multiply v = (BT x d) x B
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert v
+  // Insert (H, W) to (1, 1, H, W, N, C)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  auto combinedVal = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols,
+      /*outLoopIdx=*/4, /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3,
+      /*destSize=*/6);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
 /// This function generates linalg.batch_matmul to multiply input with filter.
 /// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
 /// tileH x tileW x H x W data as the 1-dimensional data array. That is to
@@ -108,6 +614,161 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
+// Applies the output transform y = A^T x m x A to `value`, laid out as
+// (tileH, tileW, H, W, N, F), and writes each result into `output`, laid out
+// as (N, H, W, F). leftTransform/rightTransform select whether A^T x . and
+// . x A are applied. Returns a null Value for unsupported shapes or (m, r).
+//
+// scf.for %n = 0 to N step 1
+//   scf.for %f = 0 to F step 1
+//     %extracted = extract <h x w> from value<1 x 1 x h x w x n x f>
+//     %ret = linalg.matmul AT, %extracted ; %ret = linalg.matmul %ret, A
+//     %ret = linalg.generic { scalarFactor * %ret }
+//     %inserted = insert %ret into output<n x h x w x f>
+//
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+                      Value output, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to the AT matrix: (table, rows, cols, scalarFactor).
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      ATMatrices = {
+          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+      };
+
+  // Map from (m, r) to the A matrix: (table, rows, cols, scalarFactor).
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      AMatrices = {
+          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+      };
+
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F
+  int64_t valueH = valueShape[2];
+  int64_t valueW = valueShape[3];
+  int64_t valueN = valueShape[4];
+  int64_t valueF = valueShape[5];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (valueH != alphaH && valueH != 1)
+    return Value();
+  if (valueW != alphaW && valueW != 1)
+    return Value();
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value FIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) at tile (0, 0); assumes tileH == tileW == 1 (confirm).
+  auto extractValue = extract2DData(
+      rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4,
+      /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6);
+
+  TransformMapKeyTy key = {m, r};
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  int64_t leftScalarFactor = 1;
+  int64_t rightScalarFactor = 1;
+  Value matmulRetValue = extractValue;
+  if (leftTransform) {
+    // Get AT. NOTE(review): failing here leaves the loops above in the IR.
+    auto it = ATMatrices.find(key);
+    if (it == ATMatrices.end())
+      return Value();
+    const TransformMatrix &ATMatrix = it->second;
+
+    leftScalarFactor = ATMatrix.scalarFactor;
+    retRows = ATMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType);
+    // Multiply (retRows x valueW) = AT x m
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix A (same early-return caveat as AT).
+    auto it = AMatrices.find(key);
+    if (it == AMatrices.end())
+      return Value();
+    const TransformMatrix &AMatrix = it->second;
+
+    rightScalarFactor = AMatrix.scalarFactor;
+    auto matmulType =
+        RankedTensorType::get({retRows, AMatrix.cols}, elementType);
+    retCols = AMatrix.cols;
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType);
+    // Multiply y = (AT x m) x A
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Scale by both matrices' scalarFactor product (emitted even if it is 1).
+  Value scalarFactor = rewriter.create<arith::ConstantOp>(
+      loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
+
+  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+  SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
+                                       identityAffineMap, identityAffineMap};
+  auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
+      loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
+      ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value scalarVal = args[0];
+        Value matrixVal = args[1];
+        Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
+                                                           matrixVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
+
+  // Insert y (retRows x retCols) into the loop-carried (N, H, W, F) tensor
+  // at offsets [n, 0, 0, f].
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  Value combinedVal = insert2DData(rewriter, loc, scalarMatrixOp.getResult(0),
+                                   iterArg, NIter, FIter, retRows, retCols,
+                                   /*outLoopIdx=*/0,
+                                   /*inLoopIdx=*/3, /*heightIdx=*/1,
+                                   /*widthIdx=*/2, /*destSize=*/4);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
 /// Create an empty tensor with alignedType and insert the value into the
 /// created empty tensor with aligned size.
 static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
@@ -301,6 +962,123 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   return transformedOutput.getDefiningOp();
 }
 
+FailureOr<Operation *>
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradFilterTransformOp op) {
+  Location loc = op.getLoc();
+  Value filter = op.getFilter();
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+
+  // Apply the left-side transform unless H == 1, i.e. F(1 x m, 1 x r).
+  bool leftTransform = filterH != 1;
+  // Apply the right-side transform unless W == 1, i.e. F(m x 1, r x 1).
+  bool rightTransform = filterW != 1;
+  Value transformedFilter =
+      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedFilter)
+    return failure();
+
+  rewriter.replaceOp(op, transformedFilter);
+
+  return transformedFilter.getDefiningOp();
+}
+
+FailureOr<Operation *>
+decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
+                                      linalg::WinogradInputTransformOp op) {
+  Location loc = op.getLoc();
+  Value input = op.getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+
+  // Apply the left-side transform unless H == 1, i.e. F(1 x m, 1 x r).
+  bool leftTransform = inputH != 1;
+  // Apply the right-side transform unless W == 1, i.e. F(m x 1, r x 1).
+  bool rightTransform = inputW != 1;
+  Value transformedInput =
+      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
+                     op.getR(), leftTransform, rightTransform);
+  if (!transformedInput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedInput);
+
+  return transformedInput.getDefiningOp();
+}
+
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  Location loc = op.getLoc();
+  Value value = op.getValue();
+  auto valueType = cast<ShapedType>(value.getType());
+  auto valueShape = valueType.getShape();
+  int64_t valueH = valueShape[2];
+  int64_t valueW = valueShape[3];
+
+  // Apply the left-side transform unless H == 1, i.e. F(1 x m, 1 x r).
+  bool leftTransform = valueH != 1;
+  // Apply the right-side transform unless W == 1, i.e. F(m x 1, r x 1).
+  bool rightTransform = valueW != 1;
+  Value transformedOutput =
+      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedOutput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+class DecomposeWinogradFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradFilterTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradInputTransform final
+    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradOutputTransform final
+    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradOutputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
 /// A rewrite pattern for Winograd Conv2D algorithm.
 class WinogradConv2DNhwcFhwc final
     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
@@ -336,5 +1114,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
   patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
 }
 
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<DecomposeWinogradFilterTransform>(context);
+  patterns.insert<DecomposeWinogradInputTransform>(context);
+  patterns.insert<DecomposeWinogradOutputTransform>(context);
+}
+
 } // end namespace linalg
 } // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
new file mode 100644
index 0000000000000..917d089c1981c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
+  %4 = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+  %5 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%4 : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
+  %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+  %6 = tensor.empty() : tensor<36x2x2xf32>
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<1x1x6x6x2x2xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %8 : tensor<2x4x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:   %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG5]], %[[ARG3]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S4]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:        %[[S17:.*]] = arith.mulf %[[IN]], %[[IN_9]] : f32
+// CHECK-NEXT:        linalg.yield %[[S17]] : f32
+// CHECK-NEXT:      } -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S16:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 12cb46a5968f1..5899f56da7345 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -127,6 +127,9 @@ struct TestLinalgTransforms
       *this, "test-winograd-conv2d",
       llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
       llvm::cl::init(false)};
+  Option<bool> testDecomposeWinogradOps{
+      *this, "test-decompose-winograd-ops",
+      llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
 };
 } // namespace
 
@@ -218,6 +221,12 @@ static void applyWinogradConv2D(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateDecomposeWinogradOpsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -244,6 +253,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnnecessaryInputs(getOperation());
   if (testWinogradConv2D)
     return applyWinogradConv2D(getOperation());
+  if (testDecomposeWinogradOps)
+    return applyDecomposeWinogradOps(getOperation());
 }
 
 namespace mlir {

>From cdf7647b5327ded458673e33ee087f80114dd35d Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 15:45:07 +0100
Subject: [PATCH 15/22] Address ftynse's comments

---
 .../Linalg/Transforms/WinogradConv2D.cpp      | 664 ++++++++----------
 .../Linalg/winograd-conv2d-rewrite.mlir       | 131 ++--
 2 files changed, 360 insertions(+), 435 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d2dfe366e55d3..ccd87e9e4b42c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -27,18 +27,18 @@ namespace linalg {
 
 namespace {
 
-// clang-format off
-// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
-// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
-// m is the output dimension and r is the filter dimension, is
-//
-// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
-//
-// g is filter and d is input data. We need to prepare 6 constant
-// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
-//
-// The following tables define these constant transformation matrices for
-// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
+// clang-format off
+/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+/// result. For F(m x m, r x r), with output tile size m and filter size r:
+///
+/// Y = A^T x [ (G x g x G^T) @ (B^T x d x B) ] x A
+///
+/// g is the filter and d is the input tile; x is matrix multiplication and
+/// @ is element-wise multiplication. Six constant transformation matrices,
+/// G, G^T, B^T, B, A^T, and A, are needed for this formula.
+///
+/// The following tables define these constant transformation matrices for
+/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
 constexpr float G_2x2_3x3[] = {
    -1,     0,   0,
  1./2, -1./2, 1./2,
@@ -190,6 +190,7 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+/// A constant transform matrix table together with its shape and scalar factor.
 struct TransformMatrix {
   TransformMatrix(const float *table, int64_t rows, int64_t cols,
                   int64_t scalarFactor = 1)
@@ -201,18 +202,20 @@ struct TransformMatrix {
   int64_t scalarFactor;
 };
 
-Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+/// Materializes a TransformMatrix table as an arith.constant tensor Value.
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                               TransformMatrix transform, Type type) {
-  ArrayRef<float> const_vec(transform.table, transform.rows * transform.cols);
+  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
 
-  return rewriter.create<arith::ConstantOp>(
+  return builder.create<arith::ConstantOp>(
       loc, DenseFPElementsAttr::get(
                RankedTensorType::get(
                    SmallVector<int64_t>{transform.rows, transform.cols}, type),
-               const_vec));
+               constVec));
 }
 
-Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
+/// Extract height x width data from 4D or 6D tensors.
+Value extract2DData(OpBuilder &builder, Location loc, Value source,
                     Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
                     int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
                     int64_t srcSize) {
@@ -222,84 +225,72 @@ Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
   int64_t height = sourceShape[heightIdx];
   int64_t width = sourceShape[widthIdx];
 
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
+  auto zeroIndex = builder.getIndexAttr(0);
+  auto oneIndex = builder.getIndexAttr(1);
   SmallVector<OpFoldResult, 6> offsets(srcSize, zeroIndex);
   offsets[outLoopIdx] = outLoopIndex;
   offsets[inLoopIdx] = inLoopIndex;
   SmallVector<OpFoldResult, 6> sizes(srcSize, oneIndex);
-  sizes[heightIdx] = rewriter.getIndexAttr(height);
-  sizes[widthIdx] = rewriter.getIndexAttr(width);
+  sizes[heightIdx] = builder.getIndexAttr(height);
+  sizes[widthIdx] = builder.getIndexAttr(width);
   SmallVector<OpFoldResult, 6> strides(srcSize, oneIndex);
   SmallVector<int64_t> targetShape(srcSize, 1);
   targetShape[heightIdx] = height;
   targetShape[widthIdx] = width;
 
   auto targetType = RankedTensorType::get(targetShape, elementType);
-  auto extractFilterOp = rewriter.create<tensor::ExtractSliceOp>(
+  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
       loc, targetType, source, offsets, sizes, strides);
 
   auto extractFilterType = RankedTensorType::get({height, width}, elementType);
   auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, extractFilterOp, extractFilterType);
+      builder, loc, extractFilterOp, extractFilterType);
 
   return extractFilter;
 }
 
-Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
-                   Value dest, Value outLoopIndex, Value inLoopIndex,
-                   int64_t height, int64_t width, int64_t outLoopIdx,
-                   int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
-                   int64_t destSize) {
+/// Insert a transformed height x width slice into a 4D or 6D destination
+/// tensor at the position given by the two loop indices.
+Value insert2DData(OpBuilder &builder, Location loc, Value source, Value dest,
+                   Value outLoopIndex, Value inLoopIndex, int64_t height,
+                   int64_t width, int64_t outLoopIdx, int64_t inLoopIdx,
+                   int64_t heightIdx, int64_t widthIdx, int64_t destSize) {
   auto sourceType = cast<ShapedType>(source.getType());
   Type elementType = sourceType.getElementType();
   SmallVector<int64_t> sliceShape(destSize, 1);
   sliceShape[heightIdx] = height;
   sliceShape[widthIdx] = width;
-  auto init = rewriter.create<tensor::EmptyOp>(loc, sliceShape, elementType);
-  auto result = tensor::createCanonicalRankReducingInsertSliceOp(rewriter, loc,
+  auto init = builder.create<tensor::EmptyOp>(loc, sliceShape, elementType);
+  auto result = tensor::createCanonicalRankReducingInsertSliceOp(builder, loc,
                                                                  source, init);
 
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
+  auto zeroIndex = builder.getIndexAttr(0);
+  auto oneIndex = builder.getIndexAttr(1);
   SmallVector<OpFoldResult, 6> retOffsets(destSize, zeroIndex);
   retOffsets[outLoopIdx] = outLoopIndex;
   retOffsets[inLoopIdx] = inLoopIndex;
   SmallVector<OpFoldResult, 6> retSizes(destSize, oneIndex);
-  retSizes[heightIdx] = rewriter.getIndexAttr(height);
-  retSizes[widthIdx] = rewriter.getIndexAttr(width);
+  retSizes[heightIdx] = builder.getIndexAttr(height);
+  retSizes[widthIdx] = builder.getIndexAttr(width);
   SmallVector<OpFoldResult, 6> strides(destSize, oneIndex);
 
-  auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
       loc, result, dest, retOffsets, retSizes, strides);
 
   return insertSliceOp;
 }
 
-Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
-  auto type = cast<ShapedType>(data.getType());
-  auto elementType = type.getElementType();
-  auto shape = type.getShape();
-  auto collapseType = RankedTensorType::get(
-      {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
-      elementType);
-  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
-  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
-                                                  reassociation);
-}
-
-// This function transforms the filter. The data layout of the filter is FHWC.
-// The transformation matrix is 2-dimension. We need to extract H x W from
-// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
-// After the transformation, we get
-//
-// scf.for %f = lo_f to hi_f step 1
-//   scf.for %c = lo_c to hi_c step 1
-//     %extracted = extract filter<h x w> from filter<f x h x w x c>
-//     %ret = linalg.matmul G, %extracted
-//     %ret = linalg.matmul %ret, GT
-//     %inserted = insert %ret into filter<tile_h x tile_w x h x w x c x f>
-//
+/// This function transforms the filter. The data layout of the filter is FHWC.
+/// The transformation matrix is 2-dimensional. We need to extract H x W from
+/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+/// After the transformation, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                       Value retValue, int64_t m, int64_t r,
                       bool leftTransform = true, bool rightTransform = true) {
@@ -332,100 +323,90 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
   if (filterW != r && filterW != 1)
     return Value();
 
-  // Return shape is <H x W x C x F>
+  // Resolve the constant transform matrices before creating any IR. The loop
+  // body builder below cannot bail out: it must yield one value per iter_arg.
+  TransformMapKeyTy key = {m, r};
+  auto GIt = GMatrices.find(key);
+  auto GTIt = GTMatrices.find(key);
+  if ((leftTransform && GIt == GMatrices.end()) ||
+      (rightTransform && GTIt == GTMatrices.end()))
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value FIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (F, H, W, C).
+    auto extractFilter = extract2DData(
+        builder, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    int64_t retRows = 1;
+    Value matmulRetValue = extractFilter;
+    if (leftTransform) {
+      // Constant transform matrix G, resolved above.
+      const TransformMatrix &GMatrix = GIt->second;
+      retRows = GMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
+      // Multiply G x g.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Constant transform matrix GT, resolved above.
+      const TransformMatrix &GTMatrix = GTIt->second;
+      auto matmulType =
+          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
+      // Multiply u = (G x g) x GT.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, C, F).
+    int64_t retHeight = leftTransform ? m + r - 1 : 1;
+    int64_t retWidth = rightTransform ? m + r - 1 : 1;
+    auto insertSliceOp = insert2DData(builder, loc, matmulRetValue, args[0],
+                                      FIter, CIter, retHeight, retWidth,
+                                      /*outLoopIdx=*/3, /*inLoopIdx=*/2,
+                                      /*heightIdx=*/0, /*widthIdx=*/1,
+                                      /*destSize=*/4);
+
+    return {insertSliceOp};
+  };
+
   auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
   auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
   auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-  auto outerForOp =
-      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
-  Block *outerForBody = outerForOp.getBody();
-  rewriter.setInsertionPointToStart(outerForBody);
-  Value FIter = outerForBody->getArgument(0);
-
-  auto innerForOp = rewriter.create<scf::ForOp>(
-      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
-  Block *innerForBody = innerForOp.getBody();
-  rewriter.setInsertionPointToStart(innerForBody);
-  Value CIter = innerForBody->getArgument(0);
-
-  // Extract (H, W) from (F, H, W, C)
-  auto extractFilter = extract2DData(
-      rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
-      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
-
-  TransformMapKeyTy key = {m, r};
-  int64_t retRows = 1;
-  Value matmulRetValue = extractFilter;
-  if (leftTransform) {
-    // Get constant transform matrix G
-    auto it = GMatrices.find(key);
-    if (it == GMatrices.end())
-      return Value();
-    const TransformMatrix &GMatrix = it->second;
-
-    retRows = GMatrix.rows;
-    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType);
-    // Multiply G x g
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  if (rightTransform) {
-    // Get constant transform matrix GT
-    auto it = GTMatrices.find(key);
-    if (it == GTMatrices.end())
-      return Value();
-    const TransformMatrix &GTMatrix = it->second;
-
-    auto matmulType =
-        RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value GT = create2DTransformMatrix(rewriter, loc, GTMatrix, elementType);
-    // Multiply u = (G x g) x GT
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  // Insert (H, W) to (1, 1, H, W, C, F)
-  Value iterArg = innerForOp.getRegionIterArgs()[0];
-  int64_t retHeight = leftTransform ? m + r - 1 : 1;
-  int64_t retWidth = rightTransform ? m + r - 1 : 1;
-  auto insertSliceOp = insert2DData(
-      rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth,
-      /*outLoopIdx=*/5, /*inLoopIdx=*/4, /*heightIdx=*/2, /*widthIdx=*/3,
-      /*destSize=*/6);
-
-  rewriter.create<scf::YieldOp>(loc, insertSliceOp);
-
-  rewriter.setInsertionPointToEnd(outerForBody);
-  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
-
-  rewriter.setInsertionPointAfter(outerForOp);
-
-  return outerForOp.getResult(0);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
+      {oneStep, oneStep}, {retValue}, buildBody);
+  return loops.results[0];
 }
 
-// This function transforms the input. The data layout of the input is NHWC.
-// The transformation matrix is 2-dimension. We need to extract H x W from
-// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
-// After the transformation, we get
-//
-// scf.for %n = lo_n to hi_n step 1
-//   scf.for %c = lo_c to hi_c step 1
-//     %extracted = extract input<h x w> from input<n x h x w x c>
-//     %ret = linalg.matmul BT, %extracted
-//     %ret = linalg.matmul %ret, B
-//     %inserted = insert %ret into input<h x w x n x c>
-//
+/// This function transforms the input. The data layout of the input is NHWC.
+/// The transformation matrix is 2-dimensional. We need to extract H x W from
+/// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
+/// After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
@@ -460,87 +441,76 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
   if (inputW != alphaW && inputW != 1)
     return Value();
 
+  // Resolve the constant transform matrices before creating any IR. The loop
+  // body builder below cannot bail out: it must yield one value per iter_arg.
+  TransformMapKeyTy key = {m, r};
+  auto BTIt = BTMatrices.find(key);
+  auto BIt = BMatrices.find(key);
+  if ((leftTransform && BTIt == BTMatrices.end()) ||
+      (rightTransform && BIt == BMatrices.end()))
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value NIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (N, H, W, C).
+    auto extractInput = extract2DData(
+        builder, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    int64_t retRows = 1;
+    int64_t retCols = 1;
+    Value matmulRetValue = extractInput;
+    if (leftTransform) {
+      // Constant transform matrix BT, resolved above.
+      const TransformMatrix &BTMatrix = BTIt->second;
+      retRows = BTMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      // Use the data element type, as the filter and output transforms do.
+      Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType);
+      // Multiply BT x d.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Constant transform matrix B, resolved above.
+      const TransformMatrix &BMatrix = BIt->second;
+      retCols = BMatrix.cols;
+      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+      // Use the data element type here as well.
+      Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType);
+      // Multiply v = (BT x d) x B.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, 1, 1, N, C).
+    auto combinedVal = insert2DData(
+        builder, loc, matmulRetValue, args[0], NIter, CIter, retRows, retCols,
+        /*outLoopIdx=*/4, /*inLoopIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1,
+        /*destSize=*/6);
+
+    return {combinedVal};
+  };
+
   auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
   auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
   auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-
-  auto outerForOp =
-      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
-  Block *outerForBody = outerForOp.getBody();
-  rewriter.setInsertionPointToStart(outerForBody);
-  Value NIter = outerForBody->getArgument(0);
-
-  auto innerForOp = rewriter.create<scf::ForOp>(
-      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
-  Block *innerForBody = innerForOp.getBody();
-  rewriter.setInsertionPointToStart(innerForBody);
-  Value CIter = innerForBody->getArgument(0);
-
-  // Extract (H, W) from (N, H, W, C)
-  auto extractInput = extract2DData(
-      rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0,
-      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
-
-  TransformMapKeyTy key = {m, r};
-  int64_t retRows = 1;
-  int64_t retCols = 1;
-  Value matmulRetValue = extractInput;
-  if (leftTransform) {
-    // Get constant transform matrix BT
-    auto it = BTMatrices.find(key);
-    if (it == BTMatrices.end())
-      return Value();
-    const TransformMatrix &BTMatrix = it->second;
-
-    retRows = BTMatrix.rows;
-    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value BT =
-        create2DTransformMatrix(rewriter, loc, BTMatrix, rewriter.getF32Type());
-    // Multiply BT x d
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  if (rightTransform) {
-    // Get constant transform matrix B
-    auto it = BMatrices.find(key);
-    if (it == BMatrices.end())
-      return Value();
-    const TransformMatrix &BMatrix = it->second;
-
-    retCols = BMatrix.cols;
-    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-    Value B =
-        create2DTransformMatrix(rewriter, loc, BMatrix, rewriter.getF32Type());
-    // Multiply v = (BT x d) x B
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  // Insert v
-  // Insert (H, W) to (1, 1, H, W, N, C)
-  Value iterArg = innerForOp.getRegionIterArgs()[0];
-  auto combinedVal = insert2DData(
-      rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols,
-      /*outLoopIdx=*/4, /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3,
-      /*destSize=*/6);
-
-  rewriter.create<scf::YieldOp>(loc, combinedVal);
-
-  rewriter.setInsertionPointToEnd(outerForBody);
-  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
-
-  rewriter.setInsertionPointAfter(outerForOp);
-
-  return outerForOp.getResult(0);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx}, {nUpperBound, cUpperBound},
+      {oneStep, oneStep}, {retValue}, buildBody);
+  return loops.results[0];
 }
 
 /// This function generates linalg.batch_matmul to multiply input with filter.
@@ -614,18 +584,17 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
-// This function transforms the output. The data layout of the output is HWNF.
-// The transformation matrix is 2-dimension. We need to extract H x W from
-// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
-// After the transformation, we get
-//
-// scf.for %n = lo_n to hi_n step 1
-//   scf.for %f = lo_f to hi_f step 1
-//     %extracted = extract input<h x w> from result<h x w x n x f>
-//     %ret = linalg.matmul AT, %extracted
-//     %ret = linalg.matmul %ret, A
-//     %inserted = insert %ret into ret<n x h x w x f>
-//
+/// This function transforms the output. The data layout of the output is HWNF.
+/// The transformation matrix is 2-dimensional. We need to extract H x W from
+/// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
+/// After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %f = lo_f to hi_f step 1
+///     %extracted = extract input<h x w> from result<h x w x n x f>
+///     %ret = linalg.matmul AT, %extracted
+///     %ret = linalg.matmul %ret, A
+///     %inserted = insert %ret into ret<n x h x w x f>
 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                       Value output, int64_t m, int64_t r,
                       bool leftTransform = true, bool rightTransform = true) {
@@ -647,9 +616,9 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
 
   auto valueType = cast<ShapedType>(value.getType());
   Type elementType = valueType.getElementType();
-  auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F
-  int64_t valueH = valueShape[2];
-  int64_t valueW = valueShape[3];
+  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
   int64_t valueN = valueShape[4];
   int64_t valueF = valueShape[5];
   int64_t alphaH = leftTransform ? m + r - 1 : 1;
@@ -660,113 +629,93 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
   if (valueW != alphaW && valueW != 1)
     return Value();
 
+  // Resolve the constant transform matrices before creating any IR. The loop
+  // body builder below cannot bail out: it must yield one value per iter_arg.
+  TransformMapKeyTy key = {m, r};
+  auto ATIt = ATMatrices.find(key);
+  auto AIt = AMatrices.find(key);
+  if ((leftTransform && ATIt == ATMatrices.end()) ||
+      (rightTransform && AIt == AMatrices.end()))
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value NIter = ivs[0];
+    Value FIter = ivs[1];
+
+    // Extract (H, W) from (H, W, 1, 1, N, F).
+    auto extractValue = extract2DData(
+        builder, loc, value, NIter, FIter, /*outLoopIdx=*/4,
+        /*inLoopIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1, /*srcSize=*/6);
+
+    int64_t retRows = 1;
+    int64_t retCols = 1;
+    int64_t leftScalarFactor = 1;
+    int64_t rightScalarFactor = 1;
+    Value matmulRetValue = extractValue;
+    if (leftTransform) {
+      // Constant transform matrix AT, resolved above.
+      const TransformMatrix &ATMatrix = ATIt->second;
+      leftScalarFactor = ATMatrix.scalarFactor;
+      retRows = ATMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
+      // Multiply AT x m.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Constant transform matrix A, resolved above.
+      const TransformMatrix &AMatrix = AIt->second;
+      rightScalarFactor = AMatrix.scalarFactor;
+      auto matmulType =
+          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
+      retCols = AMatrix.cols;
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
+      // Multiply y = (AT x m) x A.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Multiply by the scalar factor. linalg.fill broadcasts the scalar to a
+    // (retRows x retCols) tensor; tensor.from_elements cannot, as it needs
+    // exactly one operand per element.
+    Value scalarFactor = builder.create<arith::ConstantOp>(
+        loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                elementType);
+    auto fillOp = builder.create<linalg::FillOp>(
+        loc, ValueRange{scalarFactor}, ValueRange{init});
+    auto scaledMatmul = builder.create<linalg::MulOp>(
+        loc, matmulType, ValueRange{fillOp.getResult(0), matmulRetValue},
+        ValueRange{init});
+
+    // Insert (H, W) to (N, H, W, F).
+    Value combinedVal = insert2DData(builder, loc, scaledMatmul.getResult(0),
+                                     args[0], NIter, FIter, retRows, retCols,
+                                     /*outLoopIdx=*/0,
+                                     /*inLoopIdx=*/3, /*heightIdx=*/1,
+                                     /*widthIdx=*/2, /*destSize=*/4);
+
+    return {combinedVal};
+  };
+
   auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
   auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
   auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-
-  auto outerForOp =
-      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
-  Block *outerForBody = outerForOp.getBody();
-  rewriter.setInsertionPointToStart(outerForBody);
-  Value NIter = outerForBody->getArgument(0);
-
-  auto innerForOp = rewriter.create<scf::ForOp>(
-      loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
-  Block *innerForBody = innerForOp.getBody();
-  rewriter.setInsertionPointToStart(innerForBody);
-  Value FIter = innerForBody->getArgument(0);
-
-  // Extract (H, W) from (1, 1, H, W, N, F)
-  auto extractValue = extract2DData(
-      rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4,
-      /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6);
-
-  TransformMapKeyTy key = {m, r};
-  int64_t retRows = 1;
-  int64_t retCols = 1;
-  int64_t leftScalarFactor = 1;
-  int64_t rightScalarFactor = 1;
-  Value matmulRetValue = extractValue;
-  if (leftTransform) {
-    // Get constant transform matrix AT
-    auto it = ATMatrices.find(key);
-    if (it == ATMatrices.end())
-      return Value();
-    const TransformMatrix &ATMatrix = it->second;
-
-    leftScalarFactor = ATMatrix.scalarFactor;
-    retRows = ATMatrix.rows;
-    auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType);
-    // Multiply AT x m
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  if (rightTransform) {
-    // Get constant transform matrix T
-    auto it = AMatrices.find(key);
-    if (it == AMatrices.end())
-      return Value();
-    const TransformMatrix &AMatrix = it->second;
-
-    rightScalarFactor = AMatrix.scalarFactor;
-    auto matmulType =
-        RankedTensorType::get({retRows, AMatrix.cols}, elementType);
-    retCols = AMatrix.cols;
-    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                 elementType);
-
-    Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType);
-    // Multiply y = (AT x m) x A
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
-    matmulRetValue = matmulOp.getResult(0);
-  }
-
-  // Multiply scalar factor.
-  Value scalarFactor = rewriter.create<arith::ConstantOp>(
-      loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
-  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-  auto init =
-      rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
-
-  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
-  SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
-                                       identityAffineMap, identityAffineMap};
-  auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
-      loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
-      ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
-      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        Value scalarVal = args[0];
-        Value matrixVal = args[1];
-        Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
-                                                           matrixVal);
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
-      });
-
-  // Insert slice y
-  // Insert (H, W) to (N, H, W, F)
-  Value iterArg = innerForOp.getRegionIterArgs()[0];
-  Value combinedVal = insert2DData(rewriter, loc, scalarMatrixOp.getResult(0),
-                                   iterArg, NIter, FIter, retRows, retCols,
-                                   /*outLoopIdx=*/0,
-                                   /*inLoopIdx=*/3, /*heightIdx=*/1,
-                                   /*widthIdx=*/2, /*destSize=*/4);
-
-  rewriter.create<scf::YieldOp>(loc, combinedVal);
-
-  rewriter.setInsertionPointToEnd(outerForBody);
-  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
-
-  rewriter.setInsertionPointAfter(outerForOp);
-
-  return outerForOp.getResult(0);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx}, {nUpperBound, fUpperBound},
+      {oneStep, oneStep}, {output}, buildBody);
+  return loops.results[0];
 }
 
 /// Create an empty tensor with alignedType and insert the value into the
@@ -962,6 +911,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   return transformedOutput.getDefiningOp();
 }
 
+/// A helper function to decompose linalg.winograd_filter_transform.
 FailureOr<Operation *>
 decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                        linalg::WinogradFilterTransformOp op) {
@@ -987,6 +937,7 @@ decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
   return transformedFilter.getDefiningOp();
 }
 
+/// A helper function to decompose linalg.winograd_input_transform.
 FailureOr<Operation *>
 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradInputTransformOp op) {
@@ -1012,6 +963,7 @@ decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
   return transformedInput.getDefiningOp();
 }
 
+/// A helper function to decompose linalg.winograd_output_transform.
 FailureOr<Operation *>
 decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                        linalg::WinogradOutputTransformOp op) {
@@ -1019,8 +971,8 @@ decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
   Value value = op.getValue();
   auto valueType = cast<ShapedType>(value.getType());
   auto valueShape = valueType.getShape();
-  int64_t valueH = valueShape[2];
-  int64_t valueW = valueShape[3];
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
 
   // For F(m x 1, r x 1), we only need to do left side transform.
   bool leftTransform = valueH != 1;
@@ -1037,6 +989,7 @@ decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
   return transformedOutput.getDefiningOp();
 }
 
+/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
 class DecomposeWinogradFilterTransform final
     : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
 public:
@@ -1044,13 +997,11 @@ class DecomposeWinogradFilterTransform final
 
   LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(decomposeWinogradFilterTransformHelper(rewriter, op)))
-      return failure();
-
-    return success();
+    return decomposeWinogradFilterTransformHelper(rewriter, op);
   }
 };
 
+/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
 class DecomposeWinogradInputTransform final
     : public OpRewritePattern<linalg::WinogradInputTransformOp> {
 public:
@@ -1058,13 +1009,11 @@ class DecomposeWinogradInputTransform final
 
   LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
-      return failure();
-
-    return success();
+    return decomposeWinogradInputTransformHelper(rewriter, op);
   }
 };
 
+/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
 class DecomposeWinogradOutputTransform final
     : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
 public:
@@ -1072,10 +1021,7 @@ class DecomposeWinogradOutputTransform final
 
   LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(decomposeWinogradOutputTransformHelper(rewriter, op)))
-      return failure();
-
-    return success();
+    return decomposeWinogradOutputTransformHelper(rewriter, op);
   }
 };
 
@@ -1116,9 +1062,9 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
 
 void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.insert<DecomposeWinogradFilterTransform>(context);
-  patterns.insert<DecomposeWinogradInputTransform>(context);
-  patterns.insert<DecomposeWinogradOutputTransform>(context);
+  patterns
+      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
+              DecomposeWinogradOutputTransform>(context);
 }
 
 } // end namespace linalg
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 917d089c1981c..73fe78f065fc6 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -1,33 +1,22 @@
 // RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s
 
-#map = affine_map<(d0, d1, d2, d3) -> (0)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-  %4 = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-  %5 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%4 : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-  %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-  %collapsed_0 = tensor.collapse_shape %5 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-  %6 = tensor.empty() : tensor<36x2x2xf32>
-  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-  %expanded = tensor.expand_shape %7 [[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<1x1x6x6x2x2xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %8 : tensor<2x4x4x2xf32>
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %2 = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%2 : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+  %4 = tensor.empty() : tensor<36x2x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %6 : tensor<2x4x4x2xf32>
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d_4x4_3x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant dense<1.024000e+03> : tensor<4x4xf32>
 // CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
 // CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
 // CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
@@ -38,68 +27,58 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<1x1x6x6x5x2xf32>) {
-// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-DAG:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:    %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
 // CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
 // CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
-// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
-// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG5]], %[[ARG3]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:      %[[S8:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:      %[[S9:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S8]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[S12]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x5x2xf32>
 // CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:    scf.yield %[[S7]] : tensor<6x6x5x2xf32>
 // CHECK-NEXT:  }
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S4]]) -> (tensor<1x1x6x6x2x5xf32>) {
-// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK-NEXT:    %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x1x1x2x5xf32>) {
 // CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
 // CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S8:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S9:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S8]] : tensor<6x6xf32>) -> tensor<6x6xf32>
 // CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[S12]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x1x1x2x5xf32>
 // CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:    scf.yield %[[S7]] : tensor<6x6x1x1x2x5xf32>
 // CHECK-NEXT:  }
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32>
-// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32>
-// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:    %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6x1x1x1x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S8:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:      %[[S9:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S8]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S10]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:      %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) {
-// CHECK-NEXT:      ^bb0(%[[IN:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:        %[[S17:.*]] = arith.mulf %[[IN]], %[[IN_9]] : f32
-// CHECK-NEXT:        linalg.yield %[[S17]] : f32
-// CHECK-NEXT:      } -> tensor<4x4xf32>
-// CHECK-NEXT:      %[[S16:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.mul ins(%cst, %[[S11]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
 // CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
 // CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S9]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:    scf.yield %[[S7]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT:  }
-// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT:}
+// CHECK-NEXT:  return %[[S6]] : tensor<2x4x4x2xf32>

>From a93529db926dfabbc9354888efc2917bd2330d62 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 27 Jun 2024 10:04:50 +0100
Subject: [PATCH 16/22] fix failed test

---
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 6 +++++-
 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir     | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e0f2d00400d63..6f03d71fd0e1f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3497,10 +3497,14 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
                          return true;
                        })
                        .Default([&](Operation *op) {
-                         op->emitError("not supported");
                          return false;
                        });
 
+  if (!supported) {
+    return emitSilenceableError()
+           << "this operation is not supported to convert to Winograd Conv2D";
+  }
+
   if (supported && failed(maybeTransformed)) {
     return emitSilenceableError() << "apply Winograd Conv2D failed";
   }
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index 0a2dcc035ebd3..c10e0ccebfd7c 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -46,7 +46,6 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-  // expected-error @+1 {{not supported}}
   %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
   return %0 : tensor<2x8x8x2xf32>
 }
@@ -54,6 +53,7 @@ func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}}
     %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }

>From 549029b5feae1bbc4aee3b9fc94134d5127e514b Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 27 Jun 2024 10:04:50 +0100
Subject: [PATCH 17/22] fix failed test

---
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 6 +++++-
 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir     | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e0f2d00400d63..6f03d71fd0e1f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3497,10 +3497,14 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
                          return true;
                        })
                        .Default([&](Operation *op) {
-                         op->emitError("not supported");
                          return false;
                        });
 
+  if (!supported) {
+    return emitSilenceableError()
+           << "this operation is not supported to convert to Winograd Conv2D";
+  }
+
   if (supported && failed(maybeTransformed)) {
     return emitSilenceableError() << "apply Winograd Conv2D failed";
   }
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index 0a2dcc035ebd3..c10e0ccebfd7c 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -46,7 +46,6 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-  // expected-error @+1 {{not supported}}
   %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
   return %0 : tensor<2x8x8x2xf32>
 }
@@ -54,6 +53,7 @@ func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}}
     %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }

>From b08accd6b17c5fc102ac823609d727c6bbde6659 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 1 Jul 2024 00:16:45 +0100
Subject: [PATCH 18/22] correct the way to broadcast a scalar value

---
 .../Linalg/Transforms/WinogradConv2D.cpp      | 48 ++++++++++++++-----
 .../Linalg/winograd-conv2d-rewrite.mlir       | 10 +++-
 2 files changed, 43 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index ccd87e9e4b42c..54cbcc3f9cca2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -686,21 +686,43 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
       matmulRetValue = matmulOp.getResult(0);
     }
 
-    // Multiply scalar factor.
-    Value scalarFactor = builder.create<arith::ConstantOp>(
-        loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
-    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-    auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                elementType);
-    Value broadcastedScalar =
-        builder.create<tensor::FromElementsOp>(loc, matmulType, scalarFactor);
-    auto scaledMatmul = builder.create<linalg::MulOp>(
-        loc, matmulType, ValueRange{broadcastedScalar, matmulRetValue},
-        ValueRange{init});
+    if (leftScalarFactor * rightScalarFactor != 1) {
+      // Multiply scalar factor.
+      Value scalarFactor = builder.create<arith::ConstantOp>(
+          loc,
+          FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+      SmallVector<AffineMap> affineMaps = {
+          AffineMap::get(2, 0, init.getContext()), identityAffineMap};
+      auto broadcastedScalar =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, matmulType, ValueRange{scalarFactor}, ValueRange{init},
+                  affineMaps,
+                  llvm::ArrayRef<utils::IteratorType>{
+                      utils::IteratorType::parallel,
+                      utils::IteratorType::parallel},
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+                  })
+              .getResult(0);
+
+      matmulRetValue = builder
+                           .create<linalg::MulOp>(
+                               loc, matmulType,
+                               ValueRange{broadcastedScalar, matmulRetValue},
+                               ValueRange{init})
+                           .getResult(0);
+    }
 
     // Insert (H, W) to (N, H, W, F).
-    Value combinedVal = insert2DData(builder, loc, scaledMatmul.getResult(0),
-                                     args[0], NIter, FIter, retRows, retCols,
+    Value combinedVal = insert2DData(builder, loc, matmulRetValue, args[0],
+                                     NIter, FIter, retRows, retCols,
                                      /*outLoopIdx=*/0,
                                      /*inLoopIdx=*/3, /*heightIdx=*/1,
                                      /*widthIdx=*/2, /*destSize=*/4);
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 73fe78f065fc6..cc5562ff22c99 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -14,9 +14,11 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
   return %6 : tensor<2x4x4x2xf32>
 }
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d_4x4_3x3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-DAG:   %[[CST:.*]] = arith.constant dense<1.024000e+03> : tensor<4x4xf32>
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
 // CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
 // CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
 // CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
@@ -73,7 +75,11 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
 // CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S10]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:      %[[S13:.*]] = linalg.mul ins(%cst, %[[S11]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[BROADCAST:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x4xf32>) {
+// CHECK-NEXT:              ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:                linalg.yield %[[IN]] : f32
+// CHECK-NEXT:              } -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.mul ins(%[[BROADCAST]], %[[S11]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
 // CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
 // CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>

>From 0bb0f05959eefe7bc828edd568e2e8f2158f3c4c Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 11 Jul 2024 10:03:09 +0100
Subject: [PATCH 19/22] clang-format

---
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 4 +---
 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp       | 4 ++--
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 03ce455a409bf..bffe7a4e7d62c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3727,9 +3727,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
                              winogradConv2D(rewriter, op, getM(), getR());
                          return true;
                        })
-                       .Default([&](Operation *op) {
-                         return false;
-                       });
+                       .Default([&](Operation *op) { return false; });
 
   if (!supported) {
     return emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 53008b876a650..9b8fa7cf6bac1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -13,11 +13,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/MathExtras.h"
 
 namespace mlir {

>From dc9cda1d05cf2346d829c863bf567c4caac2931c Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 11 Jul 2024 14:48:15 +0100
Subject: [PATCH 20/22] remove redundant include path

---
 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 2c5ac273334d6..dd2b251e9234b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

>From e60957ba0733b4b04da158291db88ea3fb020767 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 11 Jul 2024 15:01:26 +0100
Subject: [PATCH 21/22] fix clang-format errors

---
 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index dd2b251e9234b..0e10ce614472b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -27,7 +27,7 @@ namespace linalg {
 
 namespace {
 
-/// clang-format off
+// clang-format off
 /// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
 /// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
 /// m is the output dimension and r is the filter dimension, is

>From 5b48c1c9c0ee887a6591e29427d363ba56235a2f Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 12 Jul 2024 15:25:25 +0100
Subject: [PATCH 22/22] Address Max191's comments

---
 .../Linalg/Transforms/WinogradConv2D.cpp      | 260 ++++++++++++------
 .../Linalg/winograd-conv2d-rewrite.mlir       | 180 +++++++-----
 2 files changed, 283 insertions(+), 157 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 0e10ce614472b..5028cd30cbac0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -214,68 +214,121 @@ Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                constVec));
 }
 
-/// Extract height x width data from 4D or 6D tensors.
-Value extract2DData(OpBuilder &builder, Location loc, Value source,
-                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
-                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
-                    int64_t srcSize) {
+/// Extract height x width data from 4D tensors.
+Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
+                          Value loopNorFIndex, Value loopCorFIndex,
+                          Value heightOffset, Value widthOffset,
+                          int64_t extractHeight, int64_t extractWidth,
+                          int64_t loopNorFIdx, int64_t loopCorFIdx,
+                          int64_t heightIdx, int64_t widthIdx) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  int64_t srcSize = sourceType.getRank();
+
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets;
+  offsets.resize(srcSize);
+  offsets[loopNorFIdx] = loopNorFIndex;
+  offsets[loopCorFIdx] = loopCorFIndex;
+  offsets[heightIdx] = heightOffset;
+  offsets[widthIdx] = widthOffset;
+  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
+  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
+  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
+  SmallVector<OpFoldResult> strides(srcSize, oneIndex);
+
+  auto extractFilterType =
+      RankedTensorType::get({extractHeight, extractWidth}, elementType);
+  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
+      loc, extractFilterType, source, offsets, sizes, strides);
+
+  return extractFilterOp;
+}
+
+/// Extract height x width data from 6D tensors.
+Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
+                          Value tileHIndex, Value tileWIndex,
+                          Value loopNorFIndex, Value loopCorFIndex,
+                          int64_t tileHIdx, int64_t tileWIdx,
+                          int64_t loopNorFIdx, int64_t loopCorFIdx,
+                          int64_t heightIdx, int64_t widthIdx) {
   auto sourceType = cast<ShapedType>(source.getType());
   Type elementType = sourceType.getElementType();
   auto sourceShape = sourceType.getShape();
+  int64_t srcSize = sourceType.getRank();
   int64_t height = sourceShape[heightIdx];
   int64_t width = sourceShape[widthIdx];
 
   auto zeroIndex = builder.getIndexAttr(0);
   auto oneIndex = builder.getIndexAttr(1);
-  SmallVector<OpFoldResult, 6> offsets(srcSize, zeroIndex);
-  offsets[outLoopIdx] = outLoopIndex;
-  offsets[inLoopIdx] = inLoopIndex;
-  SmallVector<OpFoldResult, 6> sizes(srcSize, oneIndex);
+  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
+  offsets.resize(srcSize);
+  offsets[tileHIdx] = tileHIndex;
+  offsets[tileWIdx] = tileWIndex;
+  offsets[loopNorFIdx] = loopNorFIndex;
+  offsets[loopCorFIdx] = loopCorFIndex;
+  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
   sizes[heightIdx] = builder.getIndexAttr(height);
   sizes[widthIdx] = builder.getIndexAttr(width);
-  SmallVector<OpFoldResult, 6> strides(srcSize, oneIndex);
-  SmallVector<int64_t> targetShape(srcSize, 1);
-  targetShape[heightIdx] = height;
-  targetShape[widthIdx] = width;
-
-  auto targetType = RankedTensorType::get(targetShape, elementType);
-  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
-      loc, targetType, source, offsets, sizes, strides);
+  SmallVector<OpFoldResult> strides(srcSize, oneIndex);
 
   auto extractFilterType = RankedTensorType::get({height, width}, elementType);
-  auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
-      builder, loc, extractFilterOp, extractFilterType);
+  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
+      loc, extractFilterType, source, offsets, sizes, strides);
 
-  return extractFilter;
+  return extractFilterOp;
 }
 
-/// Insert transformed height x width data to 4D or 6D tensors which it is
+/// Insert transformed height x width data to 4D tensors which it is
 /// extracted from.
-Value insert2DData(OpBuilder &builder, Location loc, Value source, Value dest,
-                   Value outLoopIndex, Value inLoopIndex, int64_t height,
-                   int64_t width, int64_t outLoopIdx, int64_t inLoopIdx,
-                   int64_t heightIdx, int64_t widthIdx, int64_t destSize) {
-  auto sourceType = cast<ShapedType>(source.getType());
-  Type elementType = sourceType.getElementType();
-  SmallVector<int64_t> sliceShape(destSize, 1);
-  sliceShape[heightIdx] = height;
-  sliceShape[widthIdx] = width;
-  auto init = builder.create<tensor::EmptyOp>(loc, sliceShape, elementType);
-  auto result = tensor::createCanonicalRankReducingInsertSliceOp(builder, loc,
-                                                                 source, init);
+Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
+                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
+                       Value heightOffset, Value widthOffset, int64_t height,
+                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
+                       int64_t heightIdx, int64_t widthIdx) {
+  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> retOffsets;
+  retOffsets.resize(destSize);
+  retOffsets[loopNorFIdx] = loopNorFIndex;
+  retOffsets[loopCorFIdx] = loopCorFIndex;
+  retOffsets[heightIdx] = heightOffset;
+  retOffsets[widthIdx] = widthOffset;
+  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
+  retSizes[heightIdx] = builder.getIndexAttr(height);
+  retSizes[widthIdx] = builder.getIndexAttr(width);
+  SmallVector<OpFoldResult> strides(destSize, oneIndex);
+
+  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+      loc, source, dest, retOffsets, retSizes, strides);
 
+  return insertSliceOp;
+}
+
+/// Insert transformed height x width data into a 6D destination tensor at
+/// the given tile and loop indices.
+Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
+                       Value dest, Value tileHIndex, Value tileWIndex,
+                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
+                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
+                       int64_t loopNorFIdx, int64_t loopCorFIdx,
+                       int64_t heightIdx, int64_t widthIdx) {
+  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
   auto zeroIndex = builder.getIndexAttr(0);
   auto oneIndex = builder.getIndexAttr(1);
-  SmallVector<OpFoldResult, 6> retOffsets(destSize, zeroIndex);
-  retOffsets[outLoopIdx] = outLoopIndex;
-  retOffsets[inLoopIdx] = inLoopIndex;
-  SmallVector<OpFoldResult, 6> retSizes(destSize, oneIndex);
+  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
+  retOffsets.resize(destSize);
+  retOffsets[tileHIdx] = tileHIndex;
+  retOffsets[tileWIdx] = tileWIndex;
+  retOffsets[loopNorFIdx] = loopNorFIndex;
+  retOffsets[loopCorFIdx] = loopCorFIndex;
+  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
   retSizes[heightIdx] = builder.getIndexAttr(height);
   retSizes[widthIdx] = builder.getIndexAttr(width);
-  SmallVector<OpFoldResult, 6> strides(destSize, oneIndex);
+  SmallVector<OpFoldResult> strides(destSize, oneIndex);
 
   auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
-      loc, result, dest, retOffsets, retSizes, strides);
+      loc, source, dest, retOffsets, retSizes, strides);
 
   return insertSliceOp;
 }
@@ -323,15 +376,17 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
   if (filterW != r && filterW != 1)
     return Value();
 
+  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                        ValueRange args) -> scf::ValueVector {
     Value FIter = ivs[0];
     Value CIter = ivs[1];
 
     // Extract (H, W) from (F, H, W, C).
-    auto extractFilter = extract2DData(
-        builder, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
-        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+    auto extractFilter =
+        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
+                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
+                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
     TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
@@ -377,16 +432,16 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
     // Insert (H, W) to (H, W, C, F).
     int64_t retHeight = leftTransform ? m + r - 1 : 1;
     int64_t retWidth = rightTransform ? m + r - 1 : 1;
-    auto insertSliceOp = insert2DData(builder, loc, matmulRetValue, args[0],
-                                      FIter, CIter, retHeight, retWidth,
-                                      /*outLoopIdx=*/3, /*inLoopIdx=*/2,
-                                      /*heightIdx=*/0, /*widthIdx=*/1,
-                                      /*destSize=*/4);
+
+    auto insertSliceOp =
+        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
+                         zeroIdx, zeroIdx, retHeight, retWidth,
+                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
+                         /*heightIdx=*/0, /*widthIdx=*/1);
 
     return {insertSliceOp};
   };
 
-  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
   auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
   auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -401,12 +456,18 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 /// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
 /// After the transformation, we get
 ///
-/// scf.for %n = lo_n to hi_n step 1
-///   scf.for %c = lo_c to hi_c step 1
-///     %extracted = extract input<h x w> from input<n x h x w x c>
-///     %ret = linalg.matmul BT, %extracted
-///     %ret = linalg.matmul %ret, B
-///     %inserted = insert %ret into input<h x w x n x c>
+/// scf.for %h = 0 to HTile step 1
+///   scf.for %w = 0 to WTile step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %c = 0 to C step 1
+///         %extracted = extract input<alphaH x alphaW> from
+///                              input<N x H x W x C>
+///                              at [%n, (%h x m), (%w x m), %c]
+///         %ret = linalg.matmul BT, %extracted
+///         %ret = linalg.matmul %ret, B
+///         %inserted = insert %ret into
+///                            input<alphaH x alphaW x tileH x tileW x N x C>
+///                            at [0, 0, %h, %w, %n, %c]
 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
@@ -433,23 +494,38 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
   int64_t inputH = inputShape[1];
   int64_t inputW = inputShape[2];
   int64_t inputC = inputShape[3];
+  auto valueType = cast<ShapedType>(retValue.getType());
+  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
+  int64_t tileH = valueShape[2];
+  int64_t tileW = valueShape[3];
   int64_t alphaH = leftTransform ? m + r - 1 : 1;
   int64_t alphaW = rightTransform ? m + r - 1 : 1;
 
-  if (inputH != alphaH && inputH != 1)
+  if ((inputH != (tileH * m) + (r - 1)) && inputH != 1)
     return Value();
-  if (inputW != alphaW && inputW != 1)
+  if ((inputW != (tileW * m) + (r - 1)) && inputW != 1)
     return Value();
 
   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                        ValueRange args) -> scf::ValueVector {
-    Value NIter = ivs[0];
-    Value CIter = ivs[1];
+    Value tileHIter = ivs[0];
+    Value tileWIter = ivs[1];
+    Value NIter = ivs[2];
+    Value CIter = ivs[3];
+
+    auto context = builder.getContext();
+    auto affineMap =
+        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+    Value heightOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
+    Value widthOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
 
     // Extract (H, W) from (N, H, W, C).
-    auto extractInput = extract2DData(
-        builder, loc, input, NIter, CIter, /*outLoopIdx=*/0,
-        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+    auto extractInput =
+        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
+                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
+                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
     TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
@@ -463,7 +539,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       const TransformMatrix &BTMatrix = it->second;
 
       retRows = BTMatrix.rows;
-      auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
       auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                   elementType);
 
@@ -494,22 +570,25 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       matmulRetValue = matmulOp.getResult(0);
     }
 
-    // Insert (H, W) to (H, W, 1, 1, N, C).
-    auto combinedVal = insert2DData(
-        builder, loc, matmulRetValue, args[0], NIter, CIter, retRows, retCols,
-        /*outLoopIdx=*/4, /*inLoopIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1,
-        /*destSize=*/6);
+    // Insert (H, W) to (H, W, tileH, tileW, N, C).
+    auto combinedVal = insert2DDataTo6D(
+        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
+        CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
+        /*heightIdx=*/0, /*widthIdx=*/1);
 
     return {combinedVal};
   };
 
   auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
+  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
   auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
   auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
   auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   scf::LoopNest loops = scf::buildLoopNest(
-      rewriter, loc, {zeroIdx, zeroIdx}, {nUpperBound, cUpperBound},
-      {oneStep, oneStep}, {retValue}, buildBody);
+      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
+      {tileHBound, tileWBound, nUpperBound, cUpperBound},
+      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
   return loops.results[0];
 }
 
@@ -631,13 +710,16 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
 
   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                        ValueRange args) -> scf::ValueVector {
-    Value NIter = ivs[0];
-    Value FIter = ivs[1];
+    Value tileHIter = ivs[0];
+    Value tileWIter = ivs[1];
+    Value NIter = ivs[2];
+    Value FIter = ivs[3];
 
-    // Extract (H, W) from (H, W, 1, 1, N, F).
-    auto extractValue = extract2DData(
-        builder, loc, value, NIter, FIter, /*outLoopIdx=*/4,
-        /*inLoopIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1, /*srcSize=*/6);
+    // Extract (H, W) from (H, W, tileH, tileW, N, F).
+    auto extractValue =
+        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
+                            FIter, 2, 3, /*loopNorFIdx=*/4,
+                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
 
     TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
@@ -720,23 +802,37 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                            .getResult(0);
     }
 
+    auto context = builder.getContext();
+    auto affineMap =
+        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+    Value heightOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
+    Value widthOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+
     // Insert (H, W) to (N, H, W, F).
-    Value combinedVal = insert2DData(builder, loc, matmulRetValue, args[0],
-                                     NIter, FIter, retRows, retCols,
-                                     /*outLoopIdx=*/0,
-                                     /*inLoopIdx=*/3, /*heightIdx=*/1,
-                                     /*widthIdx=*/2, /*destSize=*/4);
+    Value combinedVal =
+        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
+                         heightOffset, widthOffset, retRows, retCols,
+                         /*loopNorFIdx=*/0,
+                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
+                         /*widthIdx=*/2);
 
     return {combinedVal};
   };
 
+  int64_t tilwH = valueShape[2];
+  int64_t tileW = valueShape[3];
   auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH);
+  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
   auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
   auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
   auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   scf::LoopNest loops = scf::buildLoopNest(
-      rewriter, loc, {zeroIdx, zeroIdx}, {nUpperBound, fUpperBound},
-      {oneStep, oneStep}, {output}, buildBody);
+      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
+      {tileHBound, tileWBound, nUpperBound, fUpperBound},
+      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
   return loops.results[0];
 }
 
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index cc5562ff22c99..095a6636b68dc 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -1,90 +1,120 @@
 // RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s
 
-func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<6x6x5x2xf32>
-  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-  %2 = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%2 : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-  %4 = tensor.empty() : tensor<36x2x2xf32>
-  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %6 : tensor<2x4x4x2xf32>
+func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %zero = arith.constant 0.000000e+00 : f32  // padding value for input and output
+  %filter_init = tensor.empty() : tensor<6x6x5x2xf32>
+  %filter_tf = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%filter_init : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %input_padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {  // H/W 11 -> 14 so 3x3 tiles of 6x6 fit
+  ^bb0(%in_n: index, %in_h: index, %in_w: index, %in_c: index):
+    tensor.yield %zero : f32
+  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+  %input_init = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+  %input_tf = linalg.winograd_input_transform m(4) r(3) ins(%input_padded : tensor<2x14x14x5xf32>) outs(%input_init : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  %filter_2d = tensor.collapse_shape %filter_tf [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %input_2d = tensor.collapse_shape %input_tf [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+  %bmm_init = tensor.empty() : tensor<36x18x2xf32>
+  %bmm = linalg.batch_matmul ins(%input_2d, %filter_2d : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%bmm_init : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %bmm_6d = tensor.expand_shape %bmm [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+  %output_padded = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {  // H/W 9 -> 12 = 3 tiles of 4
+  ^bb0(%out_n: index, %out_h: index, %out_w: index, %out_f: index):
+    tensor.yield %zero : f32
+  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+  %output_tf = linalg.winograd_output_transform m(4) r(3) ins(%bmm_6d : tensor<6x6x3x3x2x2xf32>) outs(%output_padded : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %result = tensor.extract_slice %output_tf[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>  // drop padding
+  return %result : tensor<2x9x9x2xf32>
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func.func @conv2d_4x4_3x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
 // CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
 // CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
 // CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
 // CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
 // CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
 // CHECK-DAG:   %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
 // CHECK-DAG:   %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT:    %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
-// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
-// CHECK-NEXT:      %[[S8:.*]] = tensor.empty() : tensor<6x3xf32>
-// CHECK-NEXT:      %[[S9:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S8]] : tensor<6x3xf32>) -> tensor<6x3xf32>
-// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[S12]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
-// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x5x2xf32>
-// CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S7]] : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  }
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK-NEXT:    %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
-// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
-// CHECK-NEXT:      %[[S8:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S9:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S8]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[S12]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S7]] : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:  }
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:  %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:    %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6x1x1x1x1xf32>
-// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x6xf32>
-// CHECK-NEXT:      %[[S8:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK-NEXT:      %[[S9:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S8]] : tensor<4x6xf32>) -> tensor<4x6xf32>
-// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S10]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:      %[[BROADCAST:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x4xf32>) {
-// CHECK-NEXT:              ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:                linalg.yield %[[IN]] : f32
-// CHECK-NEXT:              } -> tensor<4x4xf32>
-// CHECK-NEXT:      %[[S13:.*]] = linalg.mul ins(%[[BROADCAST]], %[[S11]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
-// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
-// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT:    }
-// CHECK-NEXT:    scf.yield %[[S7]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT:  }
-// CHECK-NEXT:  return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-DAG:   %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], %[[C0]], %[[C0]], %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>
+// CHECK-NEXT:       %[[S8:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:       %[[S9:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S8]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:       %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:       %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:   ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
+// CHECK-NEXT:     tensor.yield %[[CST_6]] : f32
+// CHECK-NEXT:   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:           %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:           %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S10]], %[[S11]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S13:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S15:.*]] = linalg.matmul ins(%[[S13]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_7]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT:   %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:   ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
+// CHECK-NEXT:     tensor.yield %[[CST_6]] : f32
+// CHECK-NEXT:   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:           %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:           %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:           ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:             linalg.yield %[[IN]] : f32
+// CHECK-NEXT:           } -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S16:.*]] = linalg.mul ins(%[[S15]], %[[S13]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:           %[[S18:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG10]][%[[ARG7]], %[[S17]], %[[S18]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S9]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S8]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }



More information about the Mlir-commits mailing list