[Mlir-commits] [mlir] [mlir][linalg] Implement Winograd Conv2D. (PR #94470)

Hsiangkai Wang llvmlistbot at llvm.org
Mon Jun 17 11:54:26 PDT 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/94470

>From 0b6f8ae55886b3b654c3962132041a047eab3935 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:24:07 +0100
Subject: [PATCH 1/4] [mlir][linalg] Implement Conv2D using Winograd Conv2D
 algorithm

Define high-level Winograd operators and convert linalg.conv_2d_nhwc_fhwc into
these Winograd operators. According to the Winograd Conv2D algorithm, we need
three transform operators for the input, filter, and output transformations.

The formula of the Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper, Fast Algorithms for
Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).
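
As a rough sketch (for F(4 x 4, 3 x 3); SSA names are illustrative and types
are elided), a linalg.conv_2d_nhwc_fhwc is rewritten into the following
sequence, matching the added winograd-conv2d.mlir test:

  %ft = linalg.winograd_filter_transform output_height(4) output_width(4)
          m(4) r(3) ins(%filter : ...) outs(%ft_init : ...) -> ...
  %it = linalg.winograd_input_transform output_height(4) output_width(4)
          m(4) r(3) ins(%input : ...) outs(%it_init : ...) -> ...
  %cf = tensor.collapse_shape %ft [[0, 1], [2], [3]] : ...
  %ci = tensor.collapse_shape %it [[0, 1], [2], [3]] : ...
  %mm = linalg.batch_matmul ins(%ci, %cf : ...) outs(%mm_init : ...) -> ...
  %ex = tensor.expand_shape %mm [[0, 1], [2], [3]] output_shape [...] : ...
  %y  = linalg.winograd_output_transform m(4) r(3)
          ins(%ex : ...) outs(%conv_init : ...) -> ...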
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       | 119 +++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 228 ++++++++++++++++++
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 169 +++++++++++++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  14 ++
 6 files changed, 535 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..1d8b4fb482908 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,123 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and input into a format suitable for batched matrix multiply. After
+    the matrix multiply, it converts the output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$filter,
+                       AnyRankedTensor:$output,
+                       I64Attr:$output_height,
+                       I64Attr:$output_width,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `output_height` `(` $output_height `)`
+    `output_width` `(` $output_width `)`
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+}
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and input into a format suitable for batched matrix multiply. After
+    the matrix multiply, it converts the output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       I64Attr:$output_height,
+                       I64Attr:$output_width,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `output_height` `(` $output_height `)`
+    `output_width` `(` $output_width `)`
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+}
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and input into a format suitable for batched matrix multiply. After
+    the matrix multiply, it converts the output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$value,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..828a2fbfe99f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..ddf37bbbc4ad0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,228 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+// We use F(m, r) to define the size of minimal filtering algorithms.
+// m is the output dimension and r is the filter dimension. We can get
+// the input dimension, alpha, from the formula, alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+  auto type = cast<ShapedType>(data.getType());
+  auto elementType = type.getElementType();
+  auto shape = type.getShape();
+  auto collapseType = RankedTensorType::get(
+      {shape[0] * shape[1], shape[2], shape[3]}, elementType);
+  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}, {3}};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
+                                                  reassociation);
+}
+
+// This function generates a linalg.batch_matmul to multiply the input with
+// the filter. linalg.batch_matmul only supports 3-dimensional data, so we
+// treat the H x W data as one dimension, i.e., we convert [H, W, N, C] to
+// [H x W, N, C]. In this way, the 4-dimensional input is converted into a
+// 3-dimensional representation suitable for linalg.batch_matmul.
+//
+// The batched matmul does the matrix multiply with the reduction on channel.
+//
+// We get
+//
+// %collapsed_input = tensor.collapse_shape %input
+// %collapsed_filter = tensor.collapse_shape %filter
+// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+// %expanded_ret = tensor.expand_shape %ret
+//
+// After this function, we get a return value with data layout (H, W, N, F).
+//
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                     Value transformedFilter, Value transformedInput) {
+  auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
+  auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
+
+  // Batched matrix multiply
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  auto filterShape = filterType.getShape();
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  auto inputElemType = inputType.getElementType();
+  auto inputShape = inputType.getShape();
+
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1], inputShape[2], filterShape[3]},
+      inputElemType);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                inputElemType);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // Expand matmul result
+  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}, {3}};
+  auto expandType = RankedTensorType::get(
+      {inputShape[0], inputShape[1], inputShape[2], filterShape[3]},
+      inputElemType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, expandType, matmulOp.getResult(0), reassociation);
+  return expandOutput;
+}
+
+FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
+                                            linalg::Conv2DNhwcFhwcOp convOp,
+                                            int64_t m, int64_t r) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  auto outputType = cast<ShapedType>(output.getType());
+  int64_t outputH = outputType.getShape()[1];
+  int64_t outputW = outputType.getShape()[2];
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputC = inputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
+  if ((outputH != outputW) && (outputH != 1 && outputW != 1))
+    return failure();
+  if ((filterH != filterW) && (filterH != 1 && filterW != 1))
+    return failure();
+
+  if ((outputH == 1 && filterH != 1) || (outputH != 1 && filterH == 1))
+    return failure();
+  if ((outputW == 1 && filterW != 1) || (outputW != 1 && filterW == 1))
+    return failure();
+
+  // The (m, r) configurations for which constant transform matrices exist.
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (it == validConfigs.end())
+    return failure();
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = outputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = outputW != 1;
+
+  // Create operator for filter transform
+  Type elementType = filterType.getElementType();
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+  int64_t retHeight = leftTransform ? (outputH / m) * alphaH : 1;
+  int64_t retWidth = rightTransform ? (outputW / m) * alphaW : 1;
+  auto retType = RankedTensorType::get({retHeight, retWidth, filterC, filterF},
+                                       elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, outputH, outputW, m, r);
+
+  // Create operator for input transform
+  retType =
+      RankedTensorType::get({retHeight, retWidth, inputN, inputC}, elementType);
+  retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+      loc, retType, input, retValue, outputH, outputW, m, r);
+
+  Value matmulRet =
+      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
+
+  // Create operator for output transform
+  auto transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+      loc, outputType, matmulRet, output, m, r);
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getOperation();
+}
+
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    Value filter = convOp.getInputs()[1];
+    auto filterType = cast<ShapedType>(filter.getType());
+    auto filterShape = filterType.getShape(); // F, H, W, C
+    int64_t filterH = filterShape[1];
+    int64_t filterW = filterShape[2];
+    Value output = convOp.getOutputs()[0];
+    auto outputType = cast<ShapedType>(output.getType());
+    auto outputShape = outputType.getShape(); // N, H, W, F
+    int64_t outputH = outputShape[1];
+    int64_t outputW = outputShape[2];
+
+    if (filterH != r && filterH != 1 && filterW != r && filterW != 1)
+      return failure();
+
+    if (outputH < m && outputH != 1 && outputW < m && outputW != 1)
+      return failure();
+
+    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+      return failure();
+
+    return success();
+  }
+
+private:
+  int64_t m;
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..0f827f16b6b81
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,169 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform output_height(4) output_width(4) m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform output_height(4) output_width(4) m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x2x5xf32>) -> tensor<6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_2x2_3x3(%arg0: tensor<2x4x4x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x2x2x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x4x4x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %2 : tensor<2x2x2x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_2x2_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x4x4x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<4x4x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(2) output_width(2) m(2) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<4x4x5x2xf32>) -> tensor<4x4x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<4x4x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(2) output_width(2) m(2) r(3) ins(%[[ARG0]] : tensor<2x4x4x5xf32>) outs(%[[S4]] : tensor<4x4x2x5xf32>) -> tensor<4x4x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<4x4x5x2xf32> into tensor<16x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<4x4x2x5xf32> into tensor<16x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<16x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<16x2x5xf32>, tensor<16x5x2xf32>) outs(%[[S6]] : tensor<16x2x2xf32>) -> tensor<16x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [4, 4, 2, 2] : tensor<16x2x2xf32> into tensor<4x4x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(3) ins(%[[EXPANDED]] : tensor<4x4x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x2x2x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %2 : tensor<2x2x2x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_2x2_5x5
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(2) output_width(2) m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(2) output_width(2) m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x2x5xf32>) -> tensor<6x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x1x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x1x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+  return %2 : tensor<2x1x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_1x4_1x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(1) output_width(4) m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(1) output_width(4) m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x2x5xf32>) -> tensor<1x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x1x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x1x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %2 : tensor<2x4x1x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_4x1_3x1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(4) output_width(1) m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(4) output_width(1) m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x2x5xf32>) -> tensor<6x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..4904bb24209ba 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +211,14 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
+  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/3);
+  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +243,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {

>From 56d6983369084f64f3d998d5278c85b23522b591 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:49:08 +0100
Subject: [PATCH 2/4] [mlir][linalg] Add transform operator for Winograd Conv2D
 algorithm

Add a transform operator, structured.winograd_conv2d, to convert
linalg.conv_2d_nhwc_fhwc into high-level Linalg Winograd operators.
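
For example, the new transform op can be driven from a named sequence such as
the one in the added transform-winograd-conv2d.mlir test:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1
             : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 }
             : (!transform.any_op) -> (!transform.any_op)
      transform.yield
    }
  }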
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 +++
 .../TransformOps/LinalgTransformOps.cpp       | 25 +++++++++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  6 +++
 .../Linalg/transform-winograd-conv2d.mlir     | 41 +++++++++++++++
 5 files changed, 130 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..68d0f713caad4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and input into a format suitable for batched matrix multiply. After
+    the matrix multiply, it converts the output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle to the sequence of operations that replaces
+    the original convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,
+                       I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 828a2fbfe99f7..3d61d72b924db 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to the Winograd Conv2D algorithm
+/// F(m x m, r x r), where m is the dimension size of the output and r is the
+/// dimension size of the filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..fa611b4b93cfb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return winogradConv2D(rewriter, op, getM(), getR());
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index ddf37bbbc4ad0..e71896828f8e8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -218,6 +218,12 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..231f87459f230
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<12x12x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform output_height(8) output_width(8) m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<12x12x5x2xf32>) -> tensor<12x12x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<12x12x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform output_height(8) output_width(8) m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<12x12x2x5xf32>) -> tensor<12x12x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [12, 12, 2, 2] : tensor<144x2x2xf32> into tensor<12x12x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<12x12x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }

>From 646d3e7be656644755a1bb205d3d0f514c49f989 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 17:39:49 +0100
Subject: [PATCH 3/4] [mlir][linalg] Decompose winograd operators

Convert the Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform operators into nested loops containing matrix
multiplications with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3), and F(2, 5). These configurations show that the implementation can
support different kernel sizes (3 and 5) and different output sizes (2 and 4).
Besides the symmetric 3x3 and 5x5 kernels, this patch also supports 1x3, 3x1,
1x5, and 5x1 kernels.

The implementation is based on the paper, Fast Algorithms for
Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).
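
For the filter transform, for instance, the decomposition produces a loop nest
of roughly the following shape (a sketch only; iteration arguments and types
are elided):

  scf.for %f = %c0 to %F step %c1 {
    scf.for %c = %c0 to %C step %c1 {
      %g  = tensor.extract_slice ...        // H x W tile of the filter
      %Gg = linalg.matmul ins(%G, %g) ...   // G x g
      %u  = linalg.matmul ins(%Gg, %GT) ... // (G x g) x G^T
      ... = tensor.insert_slice %u into ... // into the H x W x C x F result
    }
  }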
---
 .../Dialect/Linalg/Transforms/Transforms.h    |   3 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 769 ++++++++++++++++++
 .../Linalg/winograd-conv2d-rewrite.mlir       | 105 +++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  11 +
 4 files changed, 888 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3d61d72b924db..f6765c80a8626 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1703,6 +1703,9 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r);
 
+/// Patterns to decompose Winograd operators.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index e71896828f8e8..d815f0539e729 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -12,7 +12,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -22,6 +25,156 @@ namespace linalg {
 
 namespace {
 
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
+// where m is the output dimension and r is the filter dimension, is
+//
+// Y = A^T x [ (G x g x G^T) @ (B^T x d x B) ] x A
+//
+// g is the filter and d is the input data. We need to prepare 6 constant
+// transformation matrices, G, G^T, B^T, B, A^T, and A, for this formula.
+//
+// The following tables define these constant transformation matrices for
+// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
+constexpr float G_2x2_3x3[] = {
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = {
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = {
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = {
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = {
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+constexpr float G_4x4_3x3[] = {
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+constexpr float G_2x2_5x5[] = {
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = {
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
 using TransformMapKeyTy = std::pair<int, int>;
 
 // We use F(m, r) to define the size of minimal filtering algorithms.
@@ -35,6 +188,90 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+struct TransformMatrix {
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+  const float *table;
+  int64_t rows;
+  int64_t cols;
+  int64_t scalarFactor;
+};
+
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> const_vec(transform.table, transform.rows * transform.cols);
+
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               const_vec));
+}
+
+Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  auto sourceShape = sourceType.getShape();
+  int64_t height = sourceShape[heightIdx];
+  int64_t width = sourceShape[widthIdx];
+
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+  offsets[outLoopIdx] = outLoopIndex;
+  offsets[inLoopIdx] = inLoopIndex;
+  SmallVector<OpFoldResult, 4> sizes(4, oneIndex);
+  sizes[heightIdx] = rewriter.getIndexAttr(height);
+  sizes[widthIdx] = rewriter.getIndexAttr(width);
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+  SmallVector<int64_t> targetShape(4, 1);
+  targetShape[heightIdx] = height;
+  targetShape[widthIdx] = width;
+
+  auto targetType = RankedTensorType::get(targetShape, elementType);
+  auto extractFilterOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, targetType, source, offsets, sizes, strides);
+
+  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
+  auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, extractFilterOp, extractFilterType);
+
+  return extractFilter;
+}
+
+Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
+                   Value dest, Value outLoopIndex, Value inLoopIndex,
+                   int64_t height, int64_t width, int64_t outLoopIdx,
+                   int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  SmallVector<int64_t> sliceShape(4, 1);
+  sliceShape[heightIdx] = height;
+  sliceShape[widthIdx] = width;
+  auto init = rewriter.create<tensor::EmptyOp>(loc, sliceShape, elementType);
+  auto result = tensor::createCanonicalRankReducingInsertSliceOp(rewriter, loc,
+                                                                 source, init);
+
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> retOffsets(4, zeroIndex);
+  retOffsets[outLoopIdx] = outLoopIndex;
+  retOffsets[inLoopIdx] = inLoopIndex;
+  SmallVector<OpFoldResult, 4> retSizes(4, oneIndex);
+  retSizes[heightIdx] = rewriter.getIndexAttr(height);
+  retSizes[widthIdx] = rewriter.getIndexAttr(width);
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+      loc, result, dest, retOffsets, retSizes, strides);
+
+  return insertSliceOp;
+}
+
 Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
   auto type = cast<ShapedType>(data.getType());
   auto elementType = type.getElementType();
@@ -46,6 +283,259 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
                                                   reassociation);
 }
 
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is 2-dimensional. We need to extract the H x W
+// tile from FHWC first, and generate two levels of loops to iterate on F and
+// C. After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract filter<h x w> from filter<f x h x w x c>
+//     %ret = linalg.matmul G, %extracted
+//     %ret = linalg.matmul %ret, GT
+//     %inserted = insert %ret into filter<h x w x c x f>
+//
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  if (filterH != r && filterH != 1)
+    return Value();
+  if (filterW != r && filterW != 1)
+    return Value();
+
+  // Return shape is <H x W x C x F>
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value FIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (F, H, W, C)
+  auto extractFilter =
+      extract2DData(rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+                    /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+  TransformMapKeyTy key = {m, r};
+  int64_t retRows = 1;
+  Value matmulRetValue = extractFilter;
+  if (leftTransform) {
+    // Get constant transform matrix G
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    const TransformMatrix &GMatrix = it->second;
+
+    retRows = GMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType);
+    // Multiply G x g
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix GT
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    const TransformMatrix &GTMatrix = it->second;
+
+    auto matmulType =
+        RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value GT = create2DTransformMatrix(rewriter, loc, GTMatrix, elementType);
+    // Multiply u = (G x g) x GT
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert (H, W) into (H, W, C, F)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  int64_t retHeight = leftTransform ? m + r - 1 : 1;
+  int64_t retWidth = rightTransform ? m + r - 1 : 1;
+  auto insertSliceOp = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth,
+      /*outLoopIdx=*/3, /*inLoopIdx=*/2, /*heightIdx=*/0, /*widthIdx=*/1);
+
+  rewriter.create<scf::YieldOp>(loc, insertSliceOp);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// This function transforms the input. The data layout of the input is NHWC.
+// The transformation matrix is 2-dimensional. We need to extract an H x W tile
+// from NHWC first and generate 2 levels of loops to iterate on N and C.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract input<h x w> from input<n x h x w x c>
+//     %ret = linalg.matmul BT, %extracted
+//     %ret = linalg.matmul %ret, B
+//     %inserted = insert %ret into input<h x w x n x c>
+//
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (inputH != alphaH && inputH != 1)
+    return Value();
+  if (inputW != alphaW && inputW != 1)
+    return Value();
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (N, H, W, C)
+  auto extractInput =
+      extract2DData(rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+                    /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+  TransformMapKeyTy key = {m, r};
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  Value matmulRetValue = extractInput;
+  if (leftTransform) {
+    // Get constant transform matrix BT
+    auto it = BTMatrices.find(key);
+    if (it == BTMatrices.end())
+      return Value();
+    const TransformMatrix &BTMatrix = it->second;
+
+    retRows = BTMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value BT = create2DTransformMatrix(rewriter, loc, BTMatrix, elementType);
+    // Multiply BT x d
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix B
+    auto it = BMatrices.find(key);
+    if (it == BMatrices.end())
+      return Value();
+    const TransformMatrix &BMatrix = it->second;
+
+    retCols = BMatrix.cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+    Value B = create2DTransformMatrix(rewriter, loc, BMatrix, elementType);
+    // Multiply v = (BT x d) x B
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert v: (H, W) into (H, W, N, C)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  auto combinedVal = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols,
+      /*outLoopIdx=*/2, /*inLoopIdx=*/3, /*heightIdx=*/0, /*widthIdx=*/1);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
 // This function generates linalg.batch_matmul to multiply input with filter.
 // linalg.batch_matmul only supports 3-dimension data sets. We can treat H x W
 // data as the 1-dimension data array. That is to convert [H, W, N, C] to
@@ -95,6 +585,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
+// This function transforms the output. The data layout of the output is HWNF.
+// The transformation matrix is 2-dimensional. We need to extract an H x W tile
+// from HWNF first and generate 2 levels of loops to iterate on N and F.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %f = lo_f to hi_f step 1
+//     %extracted = extract input<h x w> from result<h x w x n x f>
+//     %ret = linalg.matmul AT, %extracted
+//     %ret = linalg.matmul %ret, A
+//     %inserted = insert %ret into ret<n x h x w x f>
+//
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+                      Value output, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to AT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      ATMatrices = {
+          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+      };
+
+  // Map from (m, r) to A transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      AMatrices = {
+          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+      };
+
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto valueShape = valueType.getShape(); // H, W, N, F
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueN = valueShape[2];
+  int64_t valueF = valueShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (valueH != alphaH && valueH != 1)
+    return Value();
+  if (valueW != alphaW && valueW != 1)
+    return Value();
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value FIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (H, W, N, F)
+  auto extractValue =
+      extract2DData(rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/2,
+                    /*inLoopIdx=*/3, /*heightIdx=*/0, /*widthIdx=*/1);
+
+  TransformMapKeyTy key = {m, r};
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  int64_t leftScalarFactor = 1;
+  int64_t rightScalarFactor = 1;
+  Value matmulRetValue = extractValue;
+  if (leftTransform) {
+    // Get constant transform matrix AT
+    auto it = ATMatrices.find(key);
+    if (it == ATMatrices.end())
+      return Value();
+    const TransformMatrix &ATMatrix = it->second;
+
+    leftScalarFactor = ATMatrix.scalarFactor;
+    retRows = ATMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType);
+    // Multiply AT x m
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix A
+    auto it = AMatrices.find(key);
+    if (it == AMatrices.end())
+      return Value();
+    const TransformMatrix &AMatrix = it->second;
+
+    rightScalarFactor = AMatrix.scalarFactor;
+    auto matmulType =
+        RankedTensorType::get({retRows, AMatrix.cols}, elementType);
+    retCols = AMatrix.cols;
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType);
+    // Multiply y = (AT x m) x A
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Multiply scalar factor.
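+  // leftScalarFactor and rightScalarFactor come from the AT/A transform
+  // matrices selected above (1, the default, for F(2x2, 3x3); 32 for
+  // F(4x4, 3x3); 16 for F(2x2, 5x5)), so e.g. F(4x4, 3x3) scales the final
+  // tile by 32 * 32 = 1024 here.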
+  Value scalarFactor = rewriter.create<arith::ConstantOp>(
+      loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
+
+  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+  SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
+                                       identityAffineMap, identityAffineMap};
+  auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
+      loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
+      ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value scalarVal = args[0];
+        Value matrixVal = args[1];
+        Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
+                                                           matrixVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
+
+  // Insert slice y: (H, W) into (N, H, W, F)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  Value combinedVal =
+      insert2DData(rewriter, loc, scalarMatrixOp.getResult(0), iterArg, NIter,
+                   FIter, retRows, retCols,
+                   /*outLoopIdx=*/0,
+                   /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
 FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
                                             linalg::Conv2DNhwcFhwcOp convOp,
                                             int64_t m, int64_t r) {
@@ -179,6 +824,123 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   return transformedOutput.getOperation();
 }
 
+FailureOr<Operation *>
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradFilterTransformOp op) {
+  Location loc = op.getLoc();
+  Value filter = op.getFilter();
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  Value transformedFilter =
+      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedFilter)
+    return failure();
+
+  rewriter.replaceOp(op, transformedFilter);
+
+  return transformedFilter.getDefiningOp();
+}
+
+FailureOr<Operation *>
+decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
+                                      linalg::WinogradInputTransformOp op) {
+  Location loc = op.getLoc();
+  Value input = op.getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = inputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = inputW != 1;
+  Value transformedInput =
+      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
+                     op.getR(), leftTransform, rightTransform);
+  if (!transformedInput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedInput);
+
+  return transformedInput.getDefiningOp();
+}
+
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  Location loc = op.getLoc();
+  Value value = op.getValue();
+  auto valueType = cast<ShapedType>(value.getType());
+  auto valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = valueH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = valueW != 1;
+  Value transformedOutput =
+      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedOutput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+class DecomposeWinogradFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradFilterTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradInputTransform final
+    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradOutputTransform final
+    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradOutputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
 class WinogradConv2DNhwcFhwc final
     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
 public:
@@ -230,5 +992,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
   patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
 }
 
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<DecomposeWinogradFilterTransform>(context);
+  patterns.insert<DecomposeWinogradInputTransform>(context);
+  patterns.insert<DecomposeWinogradOutputTransform>(context);
+}
+
 } // end namespace linalg
 } // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
new file mode 100644
index 0000000000000..092d30f2b7896
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = tensor.empty() : tensor<6x6x5x2xf32>
+  %3 = linalg.winograd_filter_transform output_height(4) output_width(4) m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %4 = tensor.empty() : tensor<6x6x2x5xf32>
+  %5 = linalg.winograd_input_transform output_height(4) output_width(4) m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%4 : tensor<6x6x2x5xf32>) -> tensor<6x6x2x5xf32>
+  %collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+  %6 = tensor.empty() : tensor<36x2x2xf32>
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %8 : tensor<2x4x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:   %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S4]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x2x5xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x2x2xf32> to tensor<6x6x1x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:        %[[S17:.*]] = arith.mulf %[[IN]], %[[IN_9]] : f32
+// CHECK-NEXT:        linalg.yield %[[S17]] : f32
+// CHECK-NEXT:      } -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S16:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4904bb24209ba..a46e8e3113678 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -127,6 +127,9 @@ struct TestLinalgTransforms
       *this, "test-winograd-conv2d",
       llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
       llvm::cl::init(false)};
+  Option<bool> testDecomposeWinogradOps{
+      *this, "test-decompose-winograd-ops",
+      llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
 };
 } // namespace
 
@@ -219,6 +222,12 @@ static void applyWinogradConv2D(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateDecomposeWinogradOpsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -245,6 +254,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnnecessaryInputs(getOperation());
   if (testWinogradConv2D)
     return applyWinogradConv2D(getOperation());
+  if (testDecomposeWinogradOps)
+    return applyDecomposeWinogradOps(getOperation());
 }
 
 namespace mlir {

>From 93112d10b0726516989366c86a4d6f32c7d2e3e1 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:44:27 +0100
Subject: [PATCH 4/4] [mlir][linalg] Implement TilingInterface for winograd
 operators

In order to support arbitrarily sized input data for conv2d, implement
TilingInterface for the Winograd operators. Before converting the Winograd
operators into nested loops with matrix multiplies, tile the input of
conv2d into the supported sizes first.

Add a transform operator, structured.decompose_winograd_op, to decompose
Winograd operators. Before applying the transform op, use tile_using_for
to tile the input data into the supported sizes. The test case shows how
to tile and decompose Winograd operators.
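As a rough sketch (handle names here are illustrative and the tiling step is
elided; see the new test case for the exact sequence), the intended use of the
new transform op is:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(
        %arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]}
          in %arg0 : (!transform.any_op) -> !transform.any_op
      // In a full pipeline, %0 would first be tiled to a supported size,
      // e.g. with transform.structured.tile_using_for, and the op below
      // would be applied to the tiled op handle instead.
      %1 = transform.structured.decompose_winograd_op %0
          : (!transform.any_op) -> (!transform.any_op)
      transform.yield
    }
  }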
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  21 +-
 .../Linalg/TransformOps/LinalgTransformOps.td |  37 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |  45 +++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 329 ++++++++++++++++++
 .../TransformOps/LinalgTransformOps.cpp       |  27 ++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  18 +
 .../transform-tile-and-winograd-rewrite.mlir  | 166 +++++++++
 7 files changed, 640 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1d8b4fb482908..657691ea0e23a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,7 +154,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+     ["getIterationDomain",
+      "getLoopIteratorTypes",
+      "getResultTilePosition",
+      "getTiledImplementation"]>]> {
   let summary = "Winograd filter transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -195,7 +200,12 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
   }];
 }
 
-def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd input transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -236,7 +246,12 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
   }];
 }
 
-def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd output transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 68d0f713caad4..71736eae38b4f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2638,4 +2638,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   }];
 }
 
+def DecomposeWinogradOp : Op<Transform_Dialect,
+    "structured.decompose_winograd_op",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Decompose Winograd operators. It converts the filter, input, and output
+    transform operators into combinations of scf, tensor, and linalg
+    operations. Before applying this transform operator, users need to tile
+    the Winograd transform operators into supported sizes.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle to the sequence of operations that replaces
+    the original operator.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f6765c80a8626..082ad3ee92039 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1319,6 +1319,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                       int64_t r);
 
+/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
+/// FHWC. The transformation matrix is 2-dimensional. We need to extract an
+/// H x W tile from FHWC first and generate 2 levels of loops to iterate on F
+/// and C. After the rewriting, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform. The data layout of the input is
+/// NHWC. The transformation matrix is 2-dimensional. We need to extract an
+/// H x W tile from NHWC first and generate 2 levels of loops to iterate on N
+/// and C. After the rewriting, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform. The data layout of the output is
+/// HWNF. The transformation matrix is 2-dimensional. We need to extract an
+/// H x W tile from HWNF first and generate 2 levels of loops to iterate on N
+/// and F. After the rewriting, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %f = lo_f to hi_f step 1
+///     %extracted = extract input<h x w> from result<h x w x n x f>
+///     %ret = linalg.matmul AT, %extracted
+///     %ret = linalg.matmul %ret, A
+///     %inserted = insert %ret into ret<n x h x w x f>
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..9e0dfc3c9361f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2737,6 +2737,335 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  SmallVector<Range> loopBounds(4);
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  IntegerAttr heightAttr = getOutputHeightAttr();
+  IntegerAttr widthAttr = getOutputWidthAttr();
+  Value output = getOutput();
+  for (unsigned dim = 0; dim < 4; ++dim) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = one;
+  }
+  // Iterate on output domain.
+  loopBounds[0].size = heightAttr;
+  loopBounds[1].size = widthAttr;
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(4,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+static Value getValueFromOpFoldResult(OpFoldResult opFoldResult,
+                                      OpBuilder &builder, Location loc) {
+  if (auto val = dyn_cast<Value>(opFoldResult)) {
+    return val;
+  } else if (auto attr = dyn_cast<Attribute>(opFoldResult)) {
+    auto intAttr = cast<IntegerAttr>(attr);
+    return builder.create<arith::ConstantOp>(loc, intAttr);
+  }
+  // This should never happen if OpFoldResult is correctly formed.
+  return nullptr;
+}
+
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Value filter = getFilter();
+  auto filterType = cast<ShapedType>(filter.getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = filterH != 1 ? alpha : 1;
+  int64_t alphaW = filterW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  auto context = builder.getContext();
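+  // Each m x m tile of the convolution output corresponds to an alpha x alpha
+  // tile of the transformed tensor, so an offset d0 in the output H/W
+  // iteration domain maps to floor(d0 / m) * alpha in the result.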
+  auto affineMap = AffineMap::get(
+      1, 0, {builder.getAffineDimExpr(0).floorDiv(m) * alpha}, context);
+
+  Location loc = getLoc();
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+
+  resultOffsets.push_back(mappedOffset1);
+  resultOffsets.push_back(mappedOffset2);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultSizes.push_back(alphaHAttr);
+  resultSizes.push_back(alphaWAttr);
+  resultSizes.push_back(sizes[2]);
+  resultSizes.push_back(sizes[3]);
+  return success();
+}
+
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> strides(4, oneAttr);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(getFilter());
+
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+                                   sliceSizes)))
+    return failure();
+
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  SmallVector<Range> loopBounds(4);
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  IntegerAttr heightAttr = getOutputHeightAttr();
+  IntegerAttr widthAttr = getOutputWidthAttr();
+  Value output = getOutput();
+  for (unsigned dim = 0; dim < 4; ++dim) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = one;
+  }
+  loopBounds[0].size = heightAttr;
+  loopBounds[1].size = widthAttr;
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(4,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Value input = getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  auto context = builder.getContext();
+  auto affineMap = AffineMap::get(
+      1, 0, {builder.getAffineDimExpr(0).floorDiv(m) * alpha}, context);
+
+  Location loc = getLoc();
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+
+  resultOffsets.push_back(mappedOffset1);
+  resultOffsets.push_back(mappedOffset2);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultSizes.push_back(alphaHAttr);
+  resultSizes.push_back(alphaWAttr);
+  resultSizes.push_back(sizes[2]);
+  resultSizes.push_back(sizes[3]);
+  return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Value input = getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> strides(4, oneAttr);
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  sliceOffsets.push_back(zeroAttr);
+  sliceOffsets.push_back(offsets[0]);
+  sliceOffsets.push_back(offsets[1]);
+  sliceOffsets.push_back(zeroAttr);
+  sliceSizes.push_back(sizes[2]);
+  sliceSizes.push_back(alphaHAttr);
+  sliceSizes.push_back(alphaWAttr);
+  sliceSizes.push_back(sizes[3]);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getInput(), sliceOffsets, sliceSizes, strides));
+
+  sliceOffsets.clear();
+  sliceSizes.clear();
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+                                   sliceSizes)))
+    return failure();
+
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  SmallVector<Range> loopBounds(4);
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value output = getOutput();
+  for (unsigned dim = 0; dim < 4; ++dim) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(4,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr mAttr = getMAttr();
+
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(offsets[1]);
+  resultOffsets.push_back(offsets[2]);
+  resultOffsets.push_back(zeroAttr);
+  resultSizes.push_back(sizes[0]);
+  resultSizes.push_back(mAttr);
+  resultSizes.push_back(mAttr);
+  resultSizes.push_back(sizes[3]);
+  return success();
+}
+
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Value value = getValue();
+  auto valueType = cast<ShapedType>(value.getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = valueH != 1 ? alpha : 1;
+  int64_t alphaW = valueW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> strides(4, oneAttr);
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  auto context = builder.getContext();
+  auto affineMap = AffineMap::get(
+      1, 0, {builder.getAffineDimExpr(0).floorDiv(m) * alpha}, context);
+
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+
+  sliceOffsets.push_back(mappedOffset1);
+  sliceOffsets.push_back(mappedOffset2);
+  sliceOffsets.push_back(zeroAttr);
+  sliceOffsets.push_back(zeroAttr);
+  sliceSizes.push_back(alphaHAttr);
+  sliceSizes.push_back(alphaWAttr);
+  sliceSizes.push_back(sizes[0]);
+  sliceSizes.push_back(sizes[3]);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, value, sliceOffsets, sliceSizes, strides));
+
+  sliceOffsets.clear();
+  sliceSizes.clear();
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+                                   sliceSizes)))
+    return failure();
+
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index fa611b4b93cfb..7729d2bccbac3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3505,6 +3505,33 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::WinogradFilterTransformOp op) {
+            return decomposeWinogradFilterTransformOp(rewriter, op);
+          })
+          .Case([&](linalg::WinogradInputTransformOp op) {
+            return decomposeWinogradInputTransformOp(rewriter, op);
+          })
+          .Case([&](linalg::WinogradOutputTransformOp op) {
+            return decomposeWinogradOutputTransformOp(rewriter, op);
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d815f0539e729..63078cda2402b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -986,6 +986,24 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
   return winogradConv2DHelper(rewriter, op, m, r);
 }
 
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op) {
+  return decomposeWinogradFilterTransformHelper(rewriter, op);
+}
+
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op) {
+  return decomposeWinogradInputTransformHelper(rewriter, op);
+}
+
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op) {
+  return decomposeWinogradOutputTransformHelper(rewriter, op);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 0000000000000..8e50faf28b29b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,166 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = tensor.empty() : tensor<12x12x5x2xf32>
+  %3 = linalg.winograd_filter_transform output_height(8) output_width(8) m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<12x12x5x2xf32>) -> tensor<12x12x5x2xf32>
+  %4 = tensor.empty() : tensor<12x12x2x5xf32>
+  %5 = linalg.winograd_input_transform output_height(8) output_width(8) m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%4 : tensor<12x12x2x5xf32>) -> tensor<12x12x2x5xf32>
+  %collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<12x12x5x2xf32> into tensor<144x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2], [3]] : tensor<12x12x2x5xf32> into tensor<144x2x5xf32>
+  %6 = tensor.empty() : tensor<144x2x2xf32>
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%6 : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1], [2], [3]] output_shape [12, 12, 2, 2] : tensor<144x2x2xf32> into tensor<12x12x2x2xf32>
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<12x12x2x2xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %8 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [4, 4, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [4, 4, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 4, 4, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %6 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %1 : (!transform.any_op) -> !transform.any_op
+    %7 = transform.structured.decompose_winograd_op %6 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ((d0 floordiv 4) * 6)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG:    %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:    %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:    %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:    %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:    %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:    %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:    %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:    %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<12x12x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = tensor.empty() : tensor<12x12x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<12x12x5x2xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<12x12x5x2xf32>) {
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S2]][%[[S13]], %[[S14]], 0, 0] [6, 6, 5, 2] [1, 1, 1, 1] : tensor<12x12x5x2xf32> to tensor<6x6x5x2xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:         %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_7]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_8]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S19]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[INSERTED_SLICE_9]] into %[[ARG10]][0, 0, %[[ARG9]], %[[ARG7]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_10]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S18]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[S16:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S17:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[S16]], %[[S17]], 0, 0] [6, 6, 5, 2] [1, 1, 1, 1] : tensor<6x6x5x2xf32> into tensor<12x12x5x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<12x12x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<12x12x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[S5:.*]] = tensor.empty() : tensor<12x12x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<12x12x2x5xf32>
+// CHECK-NEXT:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<12x12x2x5xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<12x12x2x5xf32>) {
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[ARG3]], %[[ARG5]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][%[[S13]], %[[S14]], 0, 0] [6, 6, 2, 5] [1, 1, 1, 1] : tensor<12x12x2x5xf32> to tensor<6x6x2x5xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:         %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x2x5xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S18]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[S16:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S17:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[S16]], %[[S17]], 0, 0] [6, 6, 2, 5] [1, 1, 1, 1] : tensor<6x6x2x5xf32> into tensor<12x12x2x5xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<12x12x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<12x12x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S7]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:   %[[S8:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S8]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2], [3]] output_shape [12, 12, 2, 2] : tensor<144x2x2xf32> into tensor<12x12x2x2xf32>
+// CHECK-NEXT:   %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S11:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][%[[S13]], %[[S14]], 0, 0] [6, 6, 2, 2] [1, 1, 1, 1] : tensor<12x12x2x2xf32> to tensor<6x6x2x2xf32>
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S1]][0, %[[ARG3]], %[[ARG5]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:         %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x2x2xf32> to tensor<6x6x1x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:           %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]], #[[$MAP4]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S20]] : f32, tensor<4x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) {
+// CHECK-NEXT:           ^bb0(%[[IN:.*]]: f32, %[[IN_12:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:             %[[S24:.*]] = arith.mulf %[[IN]], %[[IN_12]] : f32
+// CHECK-NEXT:             linalg.yield %[[S24]] : f32
+// CHECK-NEXT:           } -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_11]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S16]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[ARG3]], %[[ARG5]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S11]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }


