[Mlir-commits] [mlir] [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm (PR #96181)

Hsiangkai Wang llvmlistbot at llvm.org
Sat Jun 29 13:58:18 PDT 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96181

From 4240341b4f06f1b77f63b0f619cae3804d88eb68 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:24:07 +0100
Subject: [PATCH 1/5] [mlir][linalg] Implement Conv2D using Winograd Conv2D
 algorithm

Define high-level Winograd operators and convert conv_2d_nhwc_fhwc into
them. The Winograd Conv2D algorithm needs three transform operators: one
each for the filter, input, and output transformation.

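The Winograd Conv2D algorithm F(m x m, r x r) computes

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

where x denotes matrix multiplication and @ denotes element-wise
multiplication. The three transforms are

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A, where y = (G x g x G^T) @ (B^T x d x B)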
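As a minimal sketch of the formula above (illustrative only, not code from
this patch), the following C++ evaluates F(2 x 2, 3 x 3) on a single 4 x 4
input tile and checks the result against a direct 2-D correlation. It
assumes the standard F(2, 3) transformation matrices from the Lavin and Gray
paper cited below. The helper names (Mat, mul, transpose, winogradF2x2_3x3,
directConv) are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Row-major fixed-size matrix (hypothetical helper type).
template <std::size_t R, std::size_t C>
using Mat = std::array<std::array<double, C>, R>;

// F(2, 3) transformation matrices: m = 2, r = 3, alpha = m + r - 1 = 4.
constexpr Mat<4, 4> BT = {{{1, 0, -1, 0},
                           {0, 1, 1, 0},
                           {0, -1, 1, 0},
                           {0, 1, 0, -1}}};
constexpr Mat<4, 3> G = {{{1, 0, 0},
                          {0.5, 0.5, 0.5},
                          {0.5, -0.5, 0.5},
                          {0, 0, 1}}};
constexpr Mat<2, 4> AT = {{{1, 1, 1, 0},
                           {0, 1, -1, -1}}};

// Plain matrix product: (R x K) x (K x C) -> (R x C).
template <std::size_t R, std::size_t K, std::size_t C>
Mat<R, C> mul(const Mat<R, K> &a, const Mat<K, C> &b) {
  Mat<R, C> out{};
  for (std::size_t i = 0; i < R; ++i)
    for (std::size_t j = 0; j < C; ++j)
      for (std::size_t k = 0; k < K; ++k)
        out[i][j] += a[i][k] * b[k][j];
  return out;
}

template <std::size_t R, std::size_t C>
Mat<C, R> transpose(const Mat<R, C> &a) {
  Mat<C, R> out{};
  for (std::size_t i = 0; i < R; ++i)
    for (std::size_t j = 0; j < C; ++j)
      out[j][i] = a[i][j];
  return out;
}

// Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A, with @ element-wise.
Mat<2, 2> winogradF2x2_3x3(const Mat<4, 4> &d, const Mat<3, 3> &g) {
  Mat<4, 4> U = mul(mul(G, g), transpose(G));   // filter transform
  Mat<4, 4> V = mul(mul(BT, d), transpose(BT)); // input transform
  Mat<4, 4> M{};
  for (std::size_t i = 0; i < 4; ++i)
    for (std::size_t j = 0; j < 4; ++j)
      M[i][j] = U[i][j] * V[i][j];              // element-wise product
  return mul(mul(AT, M), transpose(AT));        // output transform
}

// Reference: direct 2-D correlation of a 4x4 tile with a 3x3 filter.
Mat<2, 2> directConv(const Mat<4, 4> &d, const Mat<3, 3> &g) {
  Mat<2, 2> y{};
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      for (std::size_t u = 0; u < 3; ++u)
        for (std::size_t v = 0; v < 3; ++v)
          y[i][j] += d[i + u][j + v] * g[u][v];
  return y;
}
```

The F(2, 3) matrices contain only 0, +/-1 and +/-1/2. With small integer
inputs, every intermediate value is exactly representable in double
precision, so the two results can be compared with ==.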

The implementation is based on the paper "Fast Algorithms for
Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
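For a whole convolution, the pattern splits the output into heightM x widthM
tiles, pads the input and output to whole tiles, and collapses the
transformed tensors for linalg.batch_matmul. The sketch below reproduces
only that shape arithmetic from winogradConv2DHelper and matrixMultiply in
WinogradConv2D.cpp, in plain C++ without MLIR. WinogradShapes, ceilDiv and
computeShapes are hypothetical names. ceilDiv assumes non-negative sizes,
whereas the patch itself uses llvm::divideCeilSigned.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical container for the shapes the pattern materializes.
struct WinogradShapes {
  int64_t alphaH, alphaW;                 // transformed tile size (m + r - 1)
  int64_t tileH, tileW;                   // number of output tiles
  int64_t alignedInputH, alignedInputW;   // input padded to whole tiles
  int64_t alignedOutputH, alignedOutputW; // output padded to whole tiles
  std::array<int64_t, 6> filterTransform; // (tileH, tileW, alphaH, alphaW, C, F)
  std::array<int64_t, 6> inputTransform;  // (tileH, tileW, alphaH, alphaW, N, C)
  std::array<int64_t, 3> batchMatmul;     // (tileH*tileW*alphaH*alphaW, N, F)
};

// Ceiling division, assuming a >= 0 and b > 0.
int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// input is NHWC, filter is FHWC, output is NHWF (static shapes).
WinogradShapes computeShapes(const std::array<int64_t, 4> &input,
                             const std::array<int64_t, 4> &filter,
                             const std::array<int64_t, 4> &output, int64_t m,
                             int64_t r) {
  assert(m > 0 && r > 0);
  int64_t n = input[0], c = input[3];
  int64_t f = filter[0], filterH = filter[1], filterW = filter[2];
  int64_t filterC = filter[3];
  int64_t outputH = output[1], outputW = output[2];

  // A filter dimension of size 1 is not transformed; this covers the
  // F(m x 1, r x 1) and F(1 x m, 1 x r) cases.
  bool leftTransform = filterH != 1;
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1, widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1, widthR = rightTransform ? r : 1;

  WinogradShapes s{};
  s.alphaH = heightM + heightR - 1;
  s.alphaW = widthM + widthR - 1;
  s.tileH = ceilDiv(outputH, heightM);
  s.tileW = ceilDiv(outputW, widthM);
  s.alignedInputH = s.tileH * heightM + (heightR - 1);
  s.alignedInputW = s.tileW * widthM + (widthR - 1);
  s.alignedOutputH = s.tileH * heightM;
  s.alignedOutputW = s.tileW * widthM;
  s.filterTransform = {s.tileH, s.tileW, s.alphaH, s.alphaW, filterC, f};
  s.inputTransform = {s.tileH, s.tileW, s.alphaH, s.alphaW, n, c};
  // collapse_shape [[0, 1, 2, 3], [4], [5]] merges the first four dims
  // into the batch dimension; batch_matmul then reduces over C.
  s.batchMatmul = {s.tileH * s.tileW * s.alphaH * s.alphaW, n, f};
  return s;
}
```

Take the conv2d_4x4_3x3 test case: input 2x6x6x5, filter 2x3x3x5, output
2x4x4x2, with m = 4 and r = 3. The sketch yields a 1x1x6x6x5x2 filter
transform, a 1x1x6x6x2x5 input transform and a 36x2x2 batch_matmul result,
which are the shapes in the CHECK lines of winograd-conv2d.mlir. A 5x5
output with m = 4 instead needs 2 tiles per dimension, an 8x8 padded output
and a 10x10 padded input.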
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       | 114 +++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  78 +++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 321 ++++++++++++++++++
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 248 ++++++++++++++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  13 +
 7 files changed, 779 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..de1097b6ac27b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the filter
+    and input into a format suitable for batched matrix multiply. After the
+    matrix multiply, it converts the output into the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The output Y is m x m, the filter g is r x r, and the input d is
+    (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are transformation
+    matrices, and @ denotes element-wise multiplication.
+
+    This operator represents the high-level filter transformation (G x g x G^T)
+    of the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$filter,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the filter
+    and input into a format suitable for batched matrix multiply. After the
+    matrix multiply, it converts the output into the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The output Y is m x m, the filter g is r x r, and the input d is
+    (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are transformation
+    matrices, and @ denotes element-wise multiplication.
+
+    This operator represents the high-level input transformation (B^T x d x B)
+    of the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the filter
+    and input into a format suitable for batched matrix multiply. After the
+    matrix multiply, it converts the output into the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The output Y is m x m, the filter g is r x r, and the input d is
+    (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are transformation
+    matrices, and @ denotes element-wise multiplication.
+
+    This operator represents the high-level output transformation (A^T x y x A)
+    of the Winograd Conv2D algorithm, where y = (G x g x G^T) @ (B^T x d x B).
+  }];
+
+  let arguments = (ins AnyRankedTensor:$value,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..7bf2a5bca037f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto filterElemType = filterType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (filterElemType != outputElemType) {
+    return emitOpError() << "expected element type of filter " << filterElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned filterRank = filterType.getRank();
+  if (filterRank != 4)
+    return emitOpError() << "expected rank of filter is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto inputElemType = inputType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (inputElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << inputElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned inputRank = inputType.getRank();
+  if (inputRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto valueElemType = valueType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (valueElemType != outputElemType) {
+    return emitOpError() << "expected element type of value " << valueElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned valueRank = valueType.getRank();
+  if (valueRank != 6)
+    return emitOpError() << "expected rank of value is 6";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 4)
+    return emitOpError() << "expected rank of output is 4";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..86e834d51f2fc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,321 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+// We use F(m, r) to define the size of minimal filtering algorithms.
+// m is the output dimension and r is the filter dimension. We can get
+// the input dimension, alpha, from the formula, alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+  auto type = cast<ShapedType>(data.getType());
+  auto elementType = type.getElementType();
+  auto shape = type.getShape();
+  auto collapseType = RankedTensorType::get(
+      {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
+      elementType);
+  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
+                                                  reassociation);
+}
+
+// This function generates linalg.batch_matmul to multiply the input with the
+// filter. linalg.batch_matmul only supports 3-D tensors, so we treat the
+// tileH x tileW x H x W dimensions as a single batch dimension, converting
+// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way,
+// the 6-D transformed data gets a 3-D representation that is suitable for
+// linalg.batch_matmul.
+//
+// The batched matmul performs the matrix multiply with reduction over channels.
+//
+// We get
+//
+// %collapsed_input = tensor.collapse_shape %input
+// %collapsed_filter = tensor.collapse_shape %filter
+// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+// %expanded_ret = tensor.expand_shape %ret
+//
+// After this function, we get return value with data layout
+// (tileH, tileW, H, W, N, F).
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                     Value transformedFilter, Value transformedInput) {
+  auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
+  auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
+
+  // Batched matrix multiply
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  auto filterShape = filterType.getShape();
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  auto inputElemType = inputType.getElementType();
+  auto inputShape = inputType.getShape();
+
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
+       inputShape[4], filterShape[5]},
+      inputElemType);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                inputElemType);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // Expand matmul result
+  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
+  auto expandType =
+      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+                             inputShape[3], inputShape[4], filterShape[5]},
+                            inputElemType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, expandType, matmulOp.getResult(0), reassociation);
+  return expandOutput;
+}
+
+Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
+                            RankedTensorType alignedType) {
+  Value alignedInput = rewriter.create<tensor::EmptyOp>(
+      loc, alignedType.getShape(), alignedType.getElementType());
+
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto valueType = cast<ShapedType>(value.getType());
+  auto valueShape = valueType.getShape();
+  SmallVector<OpFoldResult, 4> sizes;
+  sizes.emplace_back(rewriter.getIndexAttr(valueShape[0]));
+  sizes.emplace_back(rewriter.getIndexAttr(valueShape[1]));
+  sizes.emplace_back(rewriter.getIndexAttr(valueShape[2]));
+  sizes.emplace_back(rewriter.getIndexAttr(valueShape[3]));
+
+  return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
+                                                offsets, sizes, strides);
+}
+
+Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+                               Value value, RankedTensorType extractedType) {
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto extractedShape = extractedType.getShape();
+  SmallVector<OpFoldResult, 4> sizes;
+  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0]));
+  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1]));
+  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2]));
+  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3]));
+
+  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+                                                 offsets, sizes, strides);
+}
+
+bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::all_of(
+      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
+}
+
+FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
+                                            linalg::Conv2DNhwcFhwcOp convOp,
+                                            int64_t m, int64_t r) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  if (!hasAllOneValues(convOp.getStrides()))
+    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+  auto filterShape = filterType.getShape();
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  auto inputShape = inputType.getShape();
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  auto outputShape = outputType.getShape();
+  int64_t outputN = outputShape[0];
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t outputF = outputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
+  bool isSupportedFilter = false;
+  if (filterH == filterW && filterH == r)
+    isSupportedFilter = true;
+  if (filterH == r && filterW == 1)
+    isSupportedFilter = true;
+  if (filterH == 1 && filterW == r)
+    isSupportedFilter = true;
+
+  if (!isSupportedFilter)
+    return rewriter.notifyMatchFailure(
+        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5)
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (it == validConfigs.end())
+    return failure();
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  int64_t heightM = leftTransform ? m : 1;
+  int64_t widthM = rightTransform ? m : 1;
+  int64_t heightR = leftTransform ? r : 1;
+  int64_t widthR = rightTransform ? r : 1;
+
+  // --- Create operator for filter transform ---
+  Type elementType = filterType.getElementType();
+  int64_t alphaH = heightM + heightR - 1;
+  int64_t alphaW = widthM + widthR - 1;
+  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+  auto retType = RankedTensorType::get(
+      {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, m, r);
+
+  // --- Create operator for input transform ---
+
+  // When input size - (r - 1) is not a multiple of the output tile size, pad
+  // the input so that it is covered by whole tiles.
+  int64_t alignedInputH = tileH * heightM + (heightR - 1);
+  int64_t alignedInputW = tileW * widthM + (widthR - 1);
+  if (alignedInputH != inputH || alignedInputW != inputW) {
+    auto alignedInputType = RankedTensorType::get(
+        {inputN, alignedInputH, alignedInputW, inputC}, elementType);
+    input = insertToAlignedTensor(rewriter, loc, input, alignedInputType);
+  }
+
+  retType = RankedTensorType::get(
+      {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType);
+  retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+      loc, retType, input, retValue, m, r);
+
+  Value matmulRet =
+      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
+
+  // --- Create operator for output transform ---
+
+  // When the output size is not a multiple of the output tile size, pad the
+  // output buffer so that whole tiles can be inserted into it.
+  int64_t alignedOutputH = tileH * heightM;
+  int64_t alignedOutputW = tileW * widthM;
+  bool isOutputUnaligned =
+      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
+  if (isOutputUnaligned) {
+    auto alignedOutputType = RankedTensorType::get(
+        {outputN, alignedOutputH, alignedOutputW, outputF}, elementType);
+    output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType);
+    outputType = alignedOutputType;
+  }
+
+  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+      loc, outputType, matmulRet, output, m, r);
+
+  // When output size is not aligned with output tile size, extract the
+  // value from the padded buffer.
+  if (isOutputUnaligned) {
+    transformedOutput = extractFromAlignedTensor(
+        rewriter, loc, transformedOutput,
+        RankedTensorType::get({outputN, outputH, outputW, outputF},
+                              elementType));
+  }
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+      return failure();
+
+    return success();
+  }
+
+private:
+  int64_t m;
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..6cca3c602d4c0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x2x2x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %2 : tensor<2x2x2x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_2x2_5x5
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x1x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x1x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+  return %2 : tensor<2x1x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_1x4_1x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x1x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x1x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %2 : tensor<2x4x1x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_4x1_3x1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_aligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = tensor.empty() : tensor<2x9x9x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x9x9x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %2 : tensor<2x9x9x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x9x9x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:  %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+// CHECK-NEXT:  %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_1
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_2
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_3
+// CHECK: linalg.conv_2d_nhwc_fhwc
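The shapes in the CHECK lines of this test file follow from simple per-dimension arithmetic. The C++ sketch below reproduces the enlarged input and output sizes and the batch size of the collapsed linalg.batch_matmul operands. It assumes stride 1 and dilation 1, as in every test here. The names planDim, untransformedDim and batchSize are hypothetical and are not part of the patch.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not from the patch): shape bookkeeping for one spatial
// dimension of a stride-1, dilation-1 convolution lowered with F(m, r).
struct WinogradDimPlan {
  int64_t outputSize; // conv output size along this dimension
  int64_t numTiles;   // ceil(outputSize / m)
  int64_t alignedOut; // numTiles * m: output enlarged to whole tiles
  int64_t alignedIn;  // alignedOut + r - 1: input enlarged to match
  int64_t alpha;      // size of one transformed tile, m + r - 1
};

WinogradDimPlan planDim(int64_t inputSize, int64_t m, int64_t r) {
  WinogradDimPlan p;
  p.outputSize = inputSize - r + 1;
  p.numTiles = (p.outputSize + m - 1) / m; // ceiling division
  p.alignedOut = p.numTiles * m;
  p.alignedIn = p.alignedOut + r - 1;
  p.alpha = m + r - 1;
  return p;
}

// A dimension whose input and filter sizes are both 1 (H in conv2d_1x4_1x3)
// is not transformed: it contributes a single tile of size 1.
WinogradDimPlan untransformedDim() { return {1, 1, 1, 1, 1}; }

// Leading (collapsed) batch size of the batch_matmul operands:
// tileH * tileW * alphaH * alphaW.
int64_t batchSize(const WinogradDimPlan &h, const WinogradDimPlan &w) {
  return h.numTiles * w.numTiles * h.alpha * w.alpha;
}
```

For conv2d_unaligned, planDim(11, 4, 3) yields 9 outputs in 3 tiles, an output enlarged to 12 and an input enlarged to 14. These match tensor<2x12x12x2xf32>, tensor<2x14x14x5xf32> and the 324-element batch in the CHECK lines.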
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..12cb46a5968f1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transforming conv2d with the Winograd algorithm"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
+  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {

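The tests above lower each convolution to a collapse_shape / batch_matmul / expand_shape sequence. The plain C++ reference below shows how to read that sequence. It is not MLIR code and not part of the patch, and collapsedBatchIndex and batchMatmul are illustrative names. Collapsing with reassociation [[0, 1, 2, 3], [4], [5]] folds the four leading dimensions into one row-major batch index. Each batch entry then becomes an independent N x C by C x F matrix product that reduces over the channel C.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major linear index of (th, tw, h, w) in a [TH, TW, H, W] prefix, i.e. the
// batch index produced by collapsing dimensions [0, 1, 2, 3] into one.
std::size_t collapsedBatchIndex(std::size_t th, std::size_t tw, std::size_t h,
                                std::size_t w, std::size_t TW, std::size_t H,
                                std::size_t W) {
  return ((th * TW + tw) * H + h) * W + w;
}

// Reference batched matmul over dense row-major buffers:
//   in: [B, N, C], filt: [B, C, F], result: [B, N, F]
//   out[b][n][f] = sum over c of in[b][n][c] * filt[b][c][f]
std::vector<float> batchMatmul(const std::vector<float> &in,
                               const std::vector<float> &filt, std::size_t B,
                               std::size_t N, std::size_t C, std::size_t F) {
  std::vector<float> out(B * N * F, 0.0f);
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t f = 0; f < F; ++f) {
        float acc = 0.0f; // explicit zero accumulator
        for (std::size_t c = 0; c < C; ++c)
          acc += in[(b * N + n) * C + c] * filt[(b * C + c) * F + f];
        out[(b * N + n) * F + f] = acc;
      }
  return out;
}
```

linalg.batch_matmul accumulates into its outs operand, so the sketch makes the zero starting value explicit. With the conv2d_unaligned leading dimensions [3, 3, 6, 6], the last tile element maps to batch index 323, the final entry of the 324-element batch.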
>From bbb6542fdd7ebe713c6fe73d28314f7d6d127069 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 24 Jun 2024 11:02:19 +0100
Subject: [PATCH 2/5] Address ftynse's comments

---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  27 +--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 111 ++++++++-----
 .../Linalg/Transforms/WinogradConv2D.cpp      | 148 +++++++++--------
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 154 ++++--------------
 4 files changed, 199 insertions(+), 241 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index de1097b6ac27b..effff83d317c1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,7 +154,8 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+def Linalg_WinogradFilterTransformOp :
+    Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
   let summary = "Winograd filter transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -174,13 +175,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
     transformation (G x g x G^T) in the Winograd Conv2D algorithm.
   }];
 
-  let arguments = (ins AnyRankedTensor:$filter,
-                       AnyRankedTensor:$output,
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
+                       TensorRankOf<[AnyType], [6]>:$output,
                        I64Attr:$m,
                        I64Attr:$r
   );
 
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
   let assemblyFormat = [{
     attr-dict
     `m` `(` $m `)`
@@ -192,7 +193,8 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+def Linalg_WinogradInputTransformOp :
+    Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
   let summary = "Winograd input transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -212,13 +214,13 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
     transformation (B^T x d x B) in the Winograd Conv2D algorithm.
   }];
 
-  let arguments = (ins AnyRankedTensor:$input,
-                       AnyRankedTensor:$output,
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
+                       TensorRankOf<[AnyType], [6]>:$output,
                        I64Attr:$m,
                        I64Attr:$r
   );
 
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
   let assemblyFormat = [{
     attr-dict
     `m` `(` $m `)`
@@ -230,7 +232,8 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+def Linalg_WinogradOutputTransformOp :
+    Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
   let summary = "Winograd output transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -250,13 +253,13 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
     transformation (A^T x y x A) in the Winograd Conv2D algorithm.
   }];
 
-  let arguments = (ins AnyRankedTensor:$value,
-                       AnyRankedTensor:$output,
+  let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
+                       TensorRankOf<[AnyType], [4]>:$output,
                        I64Attr:$m,
                        I64Attr:$r
   );
 
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
     `m` `(` $m `)`
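The op descriptions above name three transforms: G x g x G^T, B^T x d x B and A^T x y x A. A one-dimensional F(2, 3) instance makes the construction concrete, because the 2-D algorithm applies the same matrices along both spatial dimensions. The matrices below are assumed for illustration. They are the F(2, 3) matrices from the Lavin and Gray paper cited in the first commit message, not a copy of the patch's constant tables. The sketch checks the Winograd result against a direct correlation.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// 1-D Winograd F(2, 3): two outputs of a 3-tap filter from alpha = 2 + 3 - 1
// = 4 inputs, computed as y = A^T [(G g) .* (B^T d)].
using Vec2 = std::array<double, 2>;
using Vec3 = std::array<double, 3>;
using Vec4 = std::array<double, 4>;

constexpr double BT[4][4] = {
    {1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1}};
constexpr double G[4][3] = {
    {1, 0, 0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0, 0, 1}};
constexpr double AT[2][4] = {{1, 1, 1, 0}, {0, 1, -1, -1}};

Vec2 winogradF23(const Vec4 &d, const Vec3 &g) {
  Vec4 u{}, v{}, prod{};
  for (int i = 0; i < 4; ++i) {
    for (int k = 0; k < 3; ++k)
      u[i] += G[i][k] * g[k];  // filter transform: G g
    for (int k = 0; k < 4; ++k)
      v[i] += BT[i][k] * d[k]; // input transform: B^T d
    prod[i] = u[i] * v[i];     // elementwise product (4 multiplications)
  }
  Vec2 y{};
  for (int i = 0; i < 2; ++i)
    for (int k = 0; k < 4; ++k)
      y[i] += AT[i][k] * prod[k]; // output transform: A^T (...)
  return y;
}

// Direct "valid" correlation along one dimension (no filter flip).
Vec2 directF23(const Vec4 &d, const Vec3 &g) {
  Vec2 y{};
  for (int i = 0; i < 2; ++i)
    for (int k = 0; k < 3; ++k)
      y[i] += d[i + k] * g[k];
  return y;
}
```

The elementwise product needs 4 multiplications, where the direct form needs 6. In 2-D, F(2x2, 3x3) correspondingly works on 4x4 input tiles.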
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7bf2a5bca037f..0b22df6d49829 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2740,22 +2740,17 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
 
 LogicalResult WinogradFilterTransformOp::verify() {
   auto filterType = cast<ShapedType>(getFilter().getType());
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  auto filterElemType = filterType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (filterElemType != outputElemType) {
-    return emitOpError() << "expected element type of input " << filterElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t r = getR();
 
-  unsigned filterRank = filterType.getRank();
-  if (filterRank != 4)
-    return emitOpError() << "expected rank of input is 4";
-
-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 6)
-    return emitOpError() << "expected rank of output is 6";
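+  if (filterH != r && filterH != 1)
+    return emitOpError("expected filter height to be either r or 1");
+  if (filterW != r && filterW != 1)
+    return emitOpError("expected filter width to be either r or 1");
+  if (filterH == 1 && filterW == 1)
+    return emitOpError("expected filter height or width to be r");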
 
   return success();
 }
@@ -2766,22 +2761,42 @@ LogicalResult WinogradFilterTransformOp::verify() {
 
 LogicalResult WinogradInputTransformOp::verify() {
   auto inputType = cast<ShapedType>(getInput().getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
   auto outputType = cast<ShapedType>(getOutput().getType());
-  auto inputElemType = inputType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (inputElemType != outputElemType) {
-    return emitOpError() << "expected element type of input " << inputElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputTileH = outputShape[0];
+  int64_t outputTileW = outputShape[1];
+  int64_t outputH = outputShape[2];
+  int64_t outputW = outputShape[3];
+  int m = getM();
+  int r = getR();
+  bool leftTransform = inputH != 1;
+  bool rightTransform = inputW != 1;
+
+  if (!leftTransform && !rightTransform)
+    return emitOpError("expected input height or width to differ from 1");

-  unsigned inputRank = inputType.getRank();
-  if (inputRank != 4)
-    return emitOpError() << "expected rank of input is 4";
+  if (leftTransform) {
+    int64_t tileH = (inputH - (r - 1)) / m;
+    if (inputH != tileH * m + (r - 1))
+      return emitOpError("expected input height to be tileH * m + r - 1");
+    if (tileH != outputTileH)
+      return emitOpError("expected output tile count along H to match");
+    if (outputH != m + r - 1)
+      return emitOpError("expected output tile height to be m + r - 1");
+  }

-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 6)
-    return emitOpError() << "expected rank of output is 6";
+  if (rightTransform) {
+    int64_t tileW = (inputW - (r - 1)) / m;
+    if (inputW != tileW * m + (r - 1))
+      return emitOpError("expected input width to be tileW * m + r - 1");
+    if (tileW != outputTileW)
+      return emitOpError("expected output tile count along W to match");
+    if (outputW != m + r - 1)
+      return emitOpError("expected output tile width to be m + r - 1");
+  }
 
   return success();
 }
@@ -2792,22 +2807,36 @@ LogicalResult WinogradInputTransformOp::verify() {
 
 LogicalResult WinogradOutputTransformOp::verify() {
   auto valueType = cast<ShapedType>(getValue().getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueTileH = valueShape[0];
+  int64_t valueTileW = valueShape[1];
+  int64_t valueH = valueShape[2];
+  int64_t valueW = valueShape[3];
   auto outputType = cast<ShapedType>(getOutput().getType());
-  auto valueElemType = valueType.getElementType();
-  auto outputElemType = outputType.getElementType();
-  if (valueElemType != outputElemType) {
-    return emitOpError() << "expected element type of value " << valueElemType
-                         << " to match element type of output "
-                         << outputElemType;
-  }
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int m = getM();
+  int r = getR();
+  bool leftTransform = valueH != 1;
+  bool rightTransform = valueW != 1;
+
+  if (!leftTransform && !rightTransform)
+    return emitOpError("expected value height or width to differ from 1");

-  unsigned valueRank = valueType.getRank();
-  if (valueRank != 6)
-    return emitOpError() << "expected rank of input is 6";
+  if (leftTransform) {
+    if (valueH != m + r - 1)
+      return emitOpError("expected value tile height to be m + r - 1");
+    if (outputH != m * valueTileH)
+      return emitOpError("expected output height to be m * tileH");
+  }

-  unsigned outputRank = outputType.getRank();
-  if (outputRank != 4)
-    return emitOpError() << "expected rank of output is 4";
+  if (rightTransform) {
+    if (valueW != m + r - 1)
+      return emitOpError("expected value tile width to be m + r - 1");
+    if (outputW != m * valueTileW)
+      return emitOpError("expected output width to be m * tileW");
+  }
 
   return success();
 }
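The verifiers above encode a few per-dimension shape rules. The standalone restatement below may help when reading them. The function names are hypothetical and this is not the patch's code. The sketch checks one dimension at a time, whereas the patch also rejects an op whose height and width are both 1.

```cpp
#include <cassert>
#include <cstdint>

// Input transform, one spatial dimension: the input size must be
// tiles * m + (r - 1), the tile count must match the output tensor's tile
// dimension, and each output tile must have size m + r - 1.
bool inputDimIsValid(int64_t inputSize, int64_t outputTiles,
                     int64_t outputTileSize, int64_t m, int64_t r) {
  if (inputSize == 1) // untransformed dimension: nothing to check
    return true;
  int64_t tiles = (inputSize - (r - 1)) / m;
  return inputSize == tiles * m + (r - 1) && tiles == outputTiles &&
         outputTileSize == m + r - 1;
}

// Output transform, one spatial dimension: each value tile has size m + r - 1
// and the output covers exactly m * tiles elements.
bool outputDimIsValid(int64_t valueTileSize, int64_t valueTiles,
                      int64_t outputSize, int64_t m, int64_t r) {
  if (valueTileSize == 1) // untransformed dimension
    return true;
  return valueTileSize == m + r - 1 && outputSize == m * valueTiles;
}
```

These rules explain why conv2d_unaligned enlarges its 11-wide input to 14 and its 9-wide output to 12 around the transform ops. An 11-wide input gives (11 - 2) / 4 = 2 tiles, but 2 * 4 + 2 = 10 is not 11.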
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 86e834d51f2fc..b5d3a0bf5ec9c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/MathExtras.h"
 
@@ -25,21 +26,26 @@ namespace {
 
 using TransformMapKeyTy = std::pair<int, int>;
 
-// We use F(m, r) to define the size of minimal filtering algorithms.
-// m is the output dimension and r is the filter dimension. We can get
-// the input dimension, alpha, from the formula, alpha = m + r - 1.
-//
-// For example, when m = 2 and r = 3, we know its input size is 4.
-// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-// 2x2 output result.
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. The input
+/// dimension, alpha, follows from the formula alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, the input size is 4: the Conv2D
+/// operates on 4x4 input tiles with a 3x3 filter and produces a 2x2
+/// output tile.
 constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
-Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+/// Utility function to collapse the leading four dimensions of data. The input
+/// shape is [tileH, tileW, H, W, N, C] or [tileH, tileW, H, W, C, F]; the
+/// function converts it to [tileH x tileW x H x W, N, C] or
+/// [tileH x tileW x H x W, C, F].
+static Value collapseData(RewriterBase &rewriter, Location loc, Value data) {
   auto type = cast<ShapedType>(data.getType());
-  auto elementType = type.getElementType();
-  auto shape = type.getShape();
+  assert(type.hasStaticShape() && "only static shapes are supported");
+  Type elementType = type.getElementType();
+  ArrayRef<int64_t> shape = type.getShape();
   auto collapseType = RankedTensorType::get(
       {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
       elementType);
@@ -48,35 +54,35 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
                                                   reassociation);
 }
 
-// This function generates linalg.batch_matmul to multiply input with filter.
-// linalg.batch_matmul only supports 3-dimension data sets. We can treat
-// tileH x tileW x H x W data as the 1-dimension data array. That is to convert
-// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we
-// can convert 6-dimension input data to 3-dimension representation that is
-// suitable for linalg.batch_matmul.
-//
-// Batched matmul will do the matrix multiply with the reduction on channel.
-//
-// We get
-//
-// %collapsed_input = tensor.collapse_shape %input
-// %collapsed_filter = tensor.collapse_shape %filter
-// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
-// %expanded_ret = tensor.expand_shape %ret
-//
-// After this function, we get return value with data layout
-// (tileH, tileW, H, W, N, F).
-Value matrixMultiply(RewriterBase &rewriter, Location loc,
-                     Value transformedFilter, Value transformedInput) {
-  auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
-  auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
-
-  // Batched matrix multiply
+/// This function generates linalg.batch_matmul to multiply input with filter.
+/// linalg.batch_matmul only supports 3-dimensional operands. We therefore
+/// flatten the leading tileH x tileW x H x W dimensions into a single batch
+/// dimension, i.e., convert [tileH, tileW, H, W, N, C] to
+/// [tileH x tileW x H x W, N, C]. This turns the 6-dimensional inputs into the
+/// 3-dimensional representation that linalg.batch_matmul expects.
+///
+/// The batched matmul reduces over the channel dimension (C).
+///
+/// We get
+///
+/// %collapsed_input = tensor.collapse_shape %input
+/// %collapsed_filter = tensor.collapse_shape %filter
+/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+/// %expanded_ret = tensor.expand_shape %ret
+///
+/// The returned value has the data layout
+/// (tileH, tileW, H, W, N, F).
+static Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                            Value transformedFilter, Value transformedInput) {
+  Value collapseFilter = collapseData(rewriter, loc, transformedFilter);
+  Value collapseInput = collapseData(rewriter, loc, transformedInput);
+
+  // Batched matrix multiply.
   auto filterType = cast<ShapedType>(transformedFilter.getType());
-  auto filterShape = filterType.getShape();
+  ArrayRef<int64_t> filterShape = filterType.getShape();
   auto inputType = cast<ShapedType>(transformedInput.getType());
-  auto inputElemType = inputType.getElementType();
-  auto inputShape = inputType.getShape();
+  Type inputElemType = inputType.getElementType();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
 
   auto matmulType = RankedTensorType::get(
       {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
@@ -89,7 +95,7 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
       loc, matmulType, ValueRange({collapseInput, collapseFilter}),
       ValueRange{init});
 
-  // Expand matmul result
+  // Expand matmul result.
   SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
   auto expandType =
       RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
@@ -100,54 +106,55 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
-Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
-                            RankedTensorType alignedType) {
+/// Create an empty tensor of type alignedType and insert value into it at
+/// offset zero, so that value occupies the leading part of each dimension.
+static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
+                                   Value value, RankedTensorType alignedType) {
   Value alignedInput = rewriter.create<tensor::EmptyOp>(
       loc, alignedType.getShape(), alignedType.getElementType());
 
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
+  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
   SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
   SmallVector<OpFoldResult, 4> strides(4, oneIndex);
 
   auto valueType = cast<ShapedType>(value.getType());
-  auto valueShape = valueType.getShape();
-  SmallVector<OpFoldResult, 4> sizes;
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[0]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[1]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[2]));
-  sizes.emplace_back(rewriter.getIndexAttr(valueShape[3]));
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  SmallVector<OpFoldResult> sizes =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(valueShape));
 
   return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
                                                 offsets, sizes, strides);
 }
 
-Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
-                               Value value, RankedTensorType extractedType) {
-  auto zeroIndex = rewriter.getIndexAttr(0);
-  auto oneIndex = rewriter.getIndexAttr(1);
+/// Extract a sub-tensor of type extractedType from value at zero offsets.
+static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+                                      Value value,
+                                      RankedTensorType extractedType) {
+  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
   SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
   SmallVector<OpFoldResult, 4> strides(4, oneIndex);
 
-  auto extractedShape = extractedType.getShape();
-  SmallVector<OpFoldResult, 4> sizes;
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2]));
-  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3]));
+  ArrayRef<int64_t> extractedShape = extractedType.getShape();
+  SmallVector<OpFoldResult> sizes =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));
 
   return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
                                                  offsets, sizes, strides);
 }
 
-bool hasAllOneValues(DenseIntElementsAttr attr) {
+/// Utility function to check whether all values in the attribute are 1.
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
   return llvm::all_of(
       attr, [](const APInt &element) { return element.getSExtValue() == 1; });
 }
 
-FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
-                                            linalg::Conv2DNhwcFhwcOp convOp,
-                                            int64_t m, int64_t r) {
+/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
+/// linalg.winograd_*_transform ops and a linalg.batch_matmul.
+static FailureOr<Operation *>
+winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
+                     int64_t m, int64_t r) {
   Value input = convOp.getInputs()[0];
   Value filter = convOp.getInputs()[1];
   Value output = convOp.getOutputs()[0];
@@ -170,23 +177,23 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   if (!hasAllOneValues(convOp.getStrides()))
     return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
 
-  auto filterShape = filterType.getShape();
+  ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterF = filterShape[0];
   int64_t filterH = filterShape[1];
   int64_t filterW = filterShape[2];
   int64_t filterC = filterShape[3];
-  auto inputShape = inputType.getShape();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputN = inputShape[0];
   int64_t inputH = inputShape[1];
   int64_t inputW = inputShape[2];
   int64_t inputC = inputShape[3];
-  auto outputShape = outputType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
   int64_t outputN = outputShape[0];
   int64_t outputH = outputShape[1];
   int64_t outputW = outputShape[2];
   int64_t outputF = outputShape[3];
 
-  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
   bool isSupportedFilter = false;
   if (filterH == filterW && filterH == r)
     isSupportedFilter = true;
@@ -199,7 +206,7 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
     return rewriter.notifyMatchFailure(
         convOp, "only support filter (r x r), (r x 1) or (1 x r)");
 
-  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5)
+  // Currently, we support (m, r) = (2, 3), (4, 3), or (2, 5).
   static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
       F_2_3, F_4_3, F_2_5};
 
@@ -222,7 +229,7 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   int64_t heightR = leftTransform ? r : 1;
   int64_t widthR = rightTransform ? r : 1;
 
-  // --- Create operator for filter transform ---
+  // --- Create operation for filter transform ---
   Type elementType = filterType.getElementType();
   int64_t alphaH = heightM + heightR - 1;
   int64_t alphaW = widthM + widthR - 1;
@@ -235,7 +242,7 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
       loc, retType, filter, retValue, m, r);
 
-  // --- Create operator for input transform ---
+  // --- Create operation for input transform ---
 
   // When input size - (r - 1) is not aligned with output tile size, we need to
   // pad the input data to create the full tiles as tiling.
@@ -257,7 +264,7 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   Value matmulRet =
       matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
 
-  // --- Create operator for output transform ---
+  // --- Create operation for output transform ---
 
   // When output size is not aligned with output tile size, we need to pad the
   // output buffer to insert the full tiles after tiling.
@@ -289,6 +296,7 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   return transformedOutput.getDefiningOp();
 }
 
+/// A rewrite pattern for the Winograd Conv2D algorithm.
 class WinogradConv2DNhwcFhwc final
     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
 public:
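The padding logic in winogradConv2DHelper (the "not aligned with output tile size" comments) can be cross-checked against the shapes in winograd-conv2d.mlir with a small standalone sketch. This is not code from the patch: WinogradDimSizes, computeDimSizes and batchMatmulBatchSize are hypothetical names. It assumes, as the CHECK lines suggest, that each spatial dimension uses ceil(outputSize / m) tiles, that alpha = m + r - 1, and that an untransformed dimension (filter size 1) uses m = r = 1, mirroring heightR/widthR above.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not part of the patch: per-dimension sizes implied by
// Winograd F(m, r) tiling of one spatial dimension. For a dimension that is
// not transformed (filter size 1), pass m = 1 and r = 1.
struct WinogradDimSizes {
  int64_t alpha;         // transformed tile size: m + r - 1
  int64_t numTiles;      // ceil(outputSize / m)
  int64_t alignedOutput; // numTiles * m, size of the padded output buffer
  int64_t alignedInput;  // alignedOutput + r - 1, size of the padded input
};

inline WinogradDimSizes computeDimSizes(int64_t outputSize, int64_t m,
                                        int64_t r) {
  assert(outputSize > 0 && m > 0 && r > 0 && "expected positive sizes");
  WinogradDimSizes s;
  s.alpha = m + r - 1;
  s.numTiles = (outputSize + m - 1) / m; // round up to whole output tiles
  s.alignedOutput = s.numTiles * m;
  s.alignedInput = s.alignedOutput + r - 1;
  return s;
}

// Leading (batch) dimension of the collapsed linalg.batch_matmul operands:
// tileH * tileW * alphaH * alphaW.
inline int64_t batchMatmulBatchSize(const WinogradDimSizes &h,
                                    const WinogradDimSizes &w) {
  return h.numTiles * w.numTiles * h.alpha * w.alpha;
}
```

For conv2d_unaligned (9x9 output, F(4x4, 3x3)) this gives alpha = 6, 3 tiles, a 12-wide output buffer, a 14-wide input buffer and a batch of 3 * 3 * 6 * 6 = 324, which matches tensor<2x14x14x5xf32>, tensor<2x12x12x2xf32> and tensor<324x2x5xf32> in the CHECK lines.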
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index 6cca3c602d4c0..d24e75a71ba92 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -1,24 +1,12 @@
 // RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
 
-func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d_4x4_3x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
 // CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
 // CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
 // CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
@@ -28,31 +16,19 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
 // CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
 // CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
 // CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
 // CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT: }
 
 // -----
 
-func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
-  %0 = tensor.empty() : tensor<2x2x2x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x2x2x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-  return %2 : tensor<2x2x2x2xf32>
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%out : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %0 : tensor<2x2x2x2xf32>
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d_2x2_5x5
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x2x2x2xf32>
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
 // CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
@@ -62,31 +38,19 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
 // CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
 // CHECK-NEXT: }
 
 // -----
 
-func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x1x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x1x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-  return %2 : tensor<2x1x4x2xf32>
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%out : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+  return %0 : tensor<2x1x4x2xf32>
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d_1x4_1x3
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x1x4x2xf32>
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32>
 // CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32>
@@ -96,31 +60,19 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
 // CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
 // CHECK-NEXT: }
 
 // -----
 
-func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x1x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x1x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-  return %2 : tensor<2x4x1x2xf32>
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%out : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %0 : tensor<2x4x1x2xf32>
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d_4x1_3x1
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
-// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:     linalg.yield %[[IN]] : f32
-// CHECK-NEXT:   } -> tensor<2x4x1x2xf32>
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32>
 // CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32>
@@ -130,31 +82,19 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
 // CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
 // CHECK-NEXT: }
 
 // -----
 
-func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-  %0 = tensor.empty() : tensor<2x8x8x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x8x8x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-  return %2 : tensor<2x8x8x2xf32>
+func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d_aligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x8x8x2xf32>
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
 // CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
 // CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
 // CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
@@ -164,31 +104,19 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
 // CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
 // CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
 // CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
 // CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
 // CHECK-NEXT: }
 
 // -----
 
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-  %0 = tensor.empty() : tensor<2x9x9x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x9x9x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
-  return %2 : tensor<2x9x9x2xf32>
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %0 : tensor<2x9x9x2xf32>
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:    linalg.yield %[[IN]] : f32
-// CHECK-NEXT:  } -> tensor<2x9x9x2xf32>
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
 // CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
 // CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
 // CHECK-NEXT:  %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
@@ -201,7 +129,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 // CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
 // CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
 // CHECK-NEXT:  %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[ARG3]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
 // CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
 // CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
@@ -209,14 +137,9 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 
 // -----
 
-func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
+func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
 }
 
 // CHECK-LABEL: conv2d_unsupported_1
@@ -224,14 +147,9 @@ func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x
 
 // -----
 
-func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-  %0 = tensor.empty() : tensor<2x4x4x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x4x4x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-  return %2 : tensor<2x4x4x2xf32>
+func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
 }
 
 // CHECK-LABEL: conv2d_unsupported_2
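In the CHECK lines above, the transformed input and filter are collapsed with reassociation [[0, 1, 2, 3], [4], [5]] before linalg.batch_matmul, and the result is expanded back afterwards. A minimal sketch of that shape arithmetic follows. collapseShape and batchMatmulShape are hypothetical helpers, not MLIR APIs, and the layouts [tileH, tileW, alphaH, alphaW, N, C] for the input and [tileH, tileW, alphaH, alphaW, C, F] for the filter are the ones this revision's tests use.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helpers, not part of the patch: shape arithmetic mirroring the
// tensor.collapse_shape / linalg.batch_matmul / tensor.expand_shape sequence.
using Shape = std::vector<int64_t>;
using Reassociation = std::vector<std::vector<size_t>>;

// Each reassociation group of source dims becomes one result dim whose size
// is the product of the group's sizes, e.g. {{0, 1, 2, 3}, {4}, {5}}.
inline Shape collapseShape(const Shape &shape, const Reassociation &groups) {
  Shape result;
  for (const auto &group : groups) {
    int64_t size = 1;
    for (size_t dim : group) {
      assert(dim < shape.size() && "reassociation index out of range");
      size *= shape[dim];
    }
    result.push_back(size);
  }
  return result;
}

// [B, N, C] x [B, C, F] -> [B, N, F], as computed by linalg.batch_matmul.
inline Shape batchMatmulShape(const Shape &lhs, const Shape &rhs) {
  assert(lhs.size() == 3 && rhs.size() == 3 && "expected rank-3 operands");
  assert(lhs[0] == rhs[0] && "batch sizes must match");
  assert(lhs[2] == rhs[1] && "reduction sizes must match");
  return {lhs[0], lhs[1], rhs[2]};
}
```

For conv2d_4x4_3x3, collapsing tensor<1x1x6x6x2x5xf32> and tensor<1x1x6x6x5x2xf32> gives 36x2x5 and 36x5x2, and the batch matmul yields 36x2x2, which tensor.expand_shape turns back into 1x1x6x6x2x2.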

From db8e7e7d1cc889dbd48c5bb926a00f72f2a21bd9 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Tue, 25 Jun 2024 14:09:04 +0100
Subject: [PATCH 3/5] Address Max191's comments

---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |   4 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  16 +-
 .../Linalg/Transforms/WinogradConv2D.cpp      | 129 ++++++++--------
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 145 +++++++++++-------
 4 files changed, 163 insertions(+), 131 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index effff83d317c1..a9007c8db3078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -176,12 +176,12 @@ def Linalg_WinogradFilterTransformOp :
   }];
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
-                       TensorRankOf<[AnyType], [6]>:$output,
+                       TensorRankOf<[AnyType], [4]>:$output,
                        I64Attr:$m,
                        I64Attr:$r
   );
 
-  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
     `m` `(` $m `)`
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0b22df6d49829..1283315f2eaef 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2766,10 +2766,10 @@ LogicalResult WinogradInputTransformOp::verify() {
   int64_t inputW = inputShape[2];
   auto outputType = cast<ShapedType>(getOutput().getType());
   ArrayRef<int64_t> outputShape = outputType.getShape();
-  int64_t outputTileH = outputShape[0];
-  int64_t outputTileW = outputShape[1];
-  int64_t outputH = outputShape[2];
-  int64_t outputW = outputShape[3];
+  int64_t outputH = outputShape[0];
+  int64_t outputW = outputShape[1];
+  int64_t outputTileH = outputShape[2];
+  int64_t outputTileW = outputShape[3];
   int m = getM();
   int r = getR();
   bool leftTransform = inputH != 1;
@@ -2808,10 +2808,10 @@ LogicalResult WinogradInputTransformOp::verify() {
 LogicalResult WinogradOutputTransformOp::verify() {
   auto valueType = cast<ShapedType>(getValue().getType());
   ArrayRef<int64_t> valueShape = valueType.getShape();
-  int64_t valueTileH = valueShape[0];
-  int64_t valueTileW = valueShape[1];
-  int64_t valueH = valueShape[2];
-  int64_t valueW = valueShape[3];
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueTileH = valueShape[2];
+  int64_t valueTileW = valueShape[3];
   auto outputType = cast<ShapedType>(getOutput().getType());
   ArrayRef<int64_t> outputShape = outputType.getShape();
   int64_t outputH = outputShape[1];
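With this revision the verifiers read the transformed tensors as (alphaH, alphaW, tileH, tileW, ...) instead of (tileH, tileW, alphaH, alphaW, ...), and the transformed filter drops the tile dimensions entirely (its rank goes from 6 to 4). The following is a hypothetical helper, not part of the patch, that recomputes the static shapes the rewrite should produce under the new layout. It assumes unit strides and dilations, and assumes that `leftTransform`/`rightTransform` select whether the height/width dimension is transformed, as the `heightM`/`widthR` names in `winogradConv2DHelper` suggest.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the patch): static shapes produced by the
// Winograd rewrite under the (alphaH, alphaW, tileH, tileW, ...) layout.
struct WinogradShapes {
  std::array<int64_t, 4> filter; // (alphaH, alphaW, C, F)
  std::array<int64_t, 6> input;  // (alphaH, alphaW, tileH, tileW, N, C)
  std::array<int64_t, 6> matmul; // (alphaH, alphaW, tileH, tileW, N, F)
};

// Ceiling division for positive operands.
static int64_t ceilDivPositive(int64_t a, int64_t b) { return (a + b - 1) / b; }

// leftTransform/rightTransform: whether the height/width dimension is
// transformed (an untransformed dimension uses m = r = 1).
WinogradShapes computeWinogradShapes(int64_t n, int64_t outputH,
                                     int64_t outputW, int64_t c, int64_t f,
                                     int64_t m, int64_t r, bool leftTransform,
                                     bool rightTransform) {
  assert(m > 0 && r > 0 && outputH > 0 && outputW > 0 && "positive sizes only");
  int64_t heightM = leftTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t widthR = rightTransform ? r : 1;
  // Each transformed tile covers m + r - 1 input elements.
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  // Number of m-sized output tiles, rounded up for unaligned outputs.
  int64_t tileH = ceilDivPositive(outputH, heightM);
  int64_t tileW = ceilDivPositive(outputW, widthM);
  WinogradShapes shapes;
  shapes.filter = {alphaH, alphaW, c, f};
  shapes.input = {alphaH, alphaW, tileH, tileW, n, c};
  shapes.matmul = {alphaH, alphaW, tileH, tileW, n, f};
  return shapes;
}
```

For example, `computeWinogradShapes(2, 8, 8, 5, 2, 4, 3, true, true)` gives a filter shape of 6x6x5x2 and an input shape of 6x6x2x2x2x5, which matches the `@conv2d_aligned` CHECK lines in winograd-conv2d.mlir.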
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b5d3a0bf5ec9c..6b46f9e07abf8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -37,23 +37,6 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
-/// Utility function to linearize data. The input shape is
-/// [tileH, tileW, H, W, N, C] or [tileH, tileW, H, W, C, F]. The function will
-/// convert the shape to [tileH x tileW x H x W, N, C] or
-/// [tileH x tileW x H x W, C, F].
-static Value collapseData(RewriterBase &rewriter, Location loc, Value data) {
-  auto type = cast<ShapedType>(data.getType());
-  assert(type.hasStaticShape() && "only support static shapes.");
-  Type elementType = type.getElementType();
-  ArrayRef<int64_t> shape = type.getShape();
-  auto collapseType = RankedTensorType::get(
-      {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
-      elementType);
-  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
-  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
-                                                  reassociation);
-}
-
 /// This function generates linalg.batch_matmul to multiply input with filter.
 /// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
 /// tileH x tileW x H x W data as the 1-dimensional data array. That is to
@@ -73,58 +56,78 @@ static Value collapseData(RewriterBase &rewriter, Location loc, Value data) {
 /// After this function, we get return value with data layout
 /// (tileH, tileW, H, W, N, F).
 static Value matrixMultiply(RewriterBase &rewriter, Location loc,
-                            Value transformedFilter, Value transformedInput) {
-  Value collapseFilter = collapseData(rewriter, loc, transformedFilter);
-  Value collapseInput = collapseData(rewriter, loc, transformedInput);
-
-  // Batched matrix multiply.
+                            Value transformedFilter, Value transformedInput,
+                            Type outputElementType) {
+  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
   auto filterType = cast<ShapedType>(transformedFilter.getType());
+  assert(filterType.hasStaticShape() && "only support static shapes.");
   ArrayRef<int64_t> filterShape = filterType.getShape();
+  Type filterElementType = filterType.getElementType();
+  auto filterReassocType = RankedTensorType::get(
+      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
+      filterElementType);
+  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
+  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, filterReassocType, transformedFilter, filterReassoc);
+
+  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
+  // (alphaH x alphaW, tileH x tileW x N, C) for input.
   auto inputType = cast<ShapedType>(transformedInput.getType());
-  Type inputElemType = inputType.getElementType();
+  assert(inputType.hasStaticShape() && "only support static shapes.");
   ArrayRef<int64_t> inputShape = inputType.getShape();
+  Type inputElementType = inputType.getElementType();
+  auto inputReassocType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
+      inputElementType);
+  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, inputReassocType, transformedInput, inputReassoc);
 
+  // Batched matrix multiply.
   auto matmulType = RankedTensorType::get(
-      {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
-       inputShape[4], filterShape[5]},
-      inputElemType);
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
+      outputElementType);
   Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                inputElemType);
+                                                outputElementType);
 
   auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
       loc, matmulType, ValueRange({collapseInput, collapseFilter}),
       ValueRange{init});
 
-  // Expand matmul result.
-  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
-  auto expandType =
+  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
+  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
+  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  auto outputReassocType =
       RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
-                             inputShape[3], inputShape[4], filterShape[5]},
-                            inputElemType);
+                             inputShape[3], inputShape[4], filterShape[3]},
+                            outputElementType);
   auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
-      loc, expandType, matmulOp.getResult(0), reassociation);
+      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
   return expandOutput;
 }
 
 /// Create an empty tensor with alignedType and insert the value into the
 /// created empty tensor with aligned size.
 static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
-                                   Value value, RankedTensorType alignedType) {
-  Value alignedInput = rewriter.create<tensor::EmptyOp>(
-      loc, alignedType.getShape(), alignedType.getElementType());
-
+                                   Value value,
+                                   ArrayRef<int64_t> alignedShape) {
   OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
-  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
-  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
-
   auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
   ArrayRef<int64_t> valueShape = valueType.getShape();
-  SmallVector<OpFoldResult> sizes =
-      getAsOpFoldResult(rewriter.getI64ArrayAttr(valueShape));
-
-  return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
-                                                offsets, sizes, strides);
+  SmallVector<OpFoldResult, 6> lowIndices(alignedShape.size(), zeroIndex);
+  SmallVector<OpFoldResult, 6> highIndices;
+  for (unsigned i = 0; i < alignedShape.size(); ++i) {
+    highIndices.emplace_back(
+        rewriter.getIndexAttr(alignedShape[i] - valueShape[i]));
+  }
+  auto alignedType = RankedTensorType::get(alignedShape, elementType);
+  Value padValue = rewriter.create<arith::ConstantOp>(
+      loc, elementType, rewriter.getZeroAttr(elementType));
+  return rewriter.create<tensor::PadOp>(loc, alignedType, value, lowIndices,
+                                        highIndices, padValue);
 }
 
 /// Extract sub-tensor with extractedType from value.
@@ -230,15 +233,15 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   int64_t widthR = rightTransform ? r : 1;
 
   // --- Create operation for filter transform ---
-  Type elementType = filterType.getElementType();
+  Type filterElementType = filterType.getElementType();
   int64_t alphaH = heightM + heightR - 1;
   int64_t alphaW = widthM + widthR - 1;
   int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
   int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
-  auto retType = RankedTensorType::get(
-      {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
-  Value retValue =
-      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
+                                       filterElementType);
+  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                                    filterElementType);
   auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
       loc, retType, filter, retValue, m, r);
 
@@ -246,23 +249,24 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
 
   // When input size - (r - 1) is not aligned with output tile size, we need to
   // pad the input data to create the full tiles as tiling.
+  Type inputElementType = inputType.getElementType();
   int64_t alignedInputH = tileH * heightM + (heightR - 1);
   int64_t alignedInputW = tileW * widthM + (widthR - 1);
   if (alignedInputH != inputH || alignedInputW != inputW) {
-    auto alignedInputType = RankedTensorType::get(
-        {inputN, alignedInputH, alignedInputW, inputC}, elementType);
-    input = insertToAlignedTensor(rewriter, loc, input, alignedInputType);
+    input = insertToAlignedTensor(
+        rewriter, loc, input, {inputN, alignedInputH, alignedInputW, inputC});
   }
 
   retType = RankedTensorType::get(
-      {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType);
-  retValue =
-      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
+  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                              inputElementType);
   auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
       loc, retType, input, retValue, m, r);
 
-  Value matmulRet =
-      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
+  Type outputElementType = outputType.getElementType();
+  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
+                                   transformedInput, outputElementType);
 
   // --- Create operation for output transform ---
 
@@ -274,8 +278,9 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
       ((alignedOutputH != outputH) || (alignedOutputW != outputW));
   if (isOutputUnaligned) {
     auto alignedOutputType = RankedTensorType::get(
-        {outputN, alignedOutputH, alignedOutputW, outputF}, elementType);
-    output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType);
+        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
+    output = insertToAlignedTensor(rewriter, loc, output,
+                                   alignedOutputType.getShape());
     outputType = alignedOutputType;
   }
 
@@ -288,7 +293,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
     transformedOutput = extractFromAlignedTensor(
         rewriter, loc, transformedOutput,
         RankedTensorType::get({outputN, outputH, outputW, outputF},
-                              elementType));
+                              outputElementType));
   }
 
   rewriter.replaceOp(convOp, transformedOutput);
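This revision replaces the `tensor.empty` plus `tensor.insert_slice` pair with a zero-padded `tensor.pad`, and folds only the two alpha dimensions into the batch dimension of `linalg.batch_matmul`. The standalone C++ sketch below uses hypothetical helper names and plain `std::array` instead of MLIR types; it only reproduces that shape arithmetic so the numbers in the `@conv2d_unaligned` test can be checked by hand. For an 11x11 input, a 3x3 filter and a 9x9 output with m = 4, there are ceil(9/4) = 3 tiles per dimension, the aligned input extent is 3 * 4 + (3 - 1) = 14, and the high padding is therefore 3.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helpers (not MLIR APIs) mirroring the static-shape arithmetic
// of insertToAlignedTensor and matrixMultiply in WinogradConv2D.cpp.

// Aligned input extent along one dimension: tiles * m + (r - 1).
int64_t alignedInputExtent(int64_t tiles, int64_t m, int64_t r) {
  return tiles * m + (r - 1);
}

// tensor.pad high padding per dimension; low padding is always 0.
std::array<int64_t, 4> highPadding(const std::array<int64_t, 4> &valueShape,
                                   const std::array<int64_t, 4> &alignedShape) {
  std::array<int64_t, 4> high{};
  for (std::size_t i = 0; i < high.size(); ++i) {
    assert(alignedShape[i] >= valueShape[i] && "padding must be non-negative");
    high[i] = alignedShape[i] - valueShape[i];
  }
  return high;
}

// Filter (alphaH, alphaW, C, F) with reassociation {{0, 1}, {2}, {3}}.
std::array<int64_t, 3> collapseFilter(const std::array<int64_t, 4> &s) {
  return {s[0] * s[1], s[2], s[3]};
}

// Input (alphaH, alphaW, tileH, tileW, N, C) with reassociation
// {{0, 1}, {2, 3, 4}, {5}}.
std::array<int64_t, 3> collapseInput(const std::array<int64_t, 6> &s) {
  return {s[0] * s[1], s[2] * s[3] * s[4], s[5]};
}

// batch_matmul result shape: (B, M, K) x (B, K, N) -> (B, M, N).
std::array<int64_t, 3> batchMatmulShape(const std::array<int64_t, 3> &lhs,
                                        const std::array<int64_t, 3> &rhs) {
  assert(lhs[0] == rhs[0] && lhs[2] == rhs[1] && "incompatible operands");
  return {lhs[0], lhs[1], rhs[2]};
}
```

With N = 2, C = 5 and F = 2, the transformed input 6x6x3x3x2x5 collapses to 36x18x5, the transformed filter 6x6x5x2 to 36x5x2, and the batched matmul yields 36x18x2, as in the CHECK lines.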
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index d24e75a71ba92..ec11a6ef8fbee 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -7,16 +7,16 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_4x4_3x3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
 // CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
 // CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
 // CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT: }
 
@@ -29,16 +29,16 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_2x2_5x5
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
 // CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
 // CHECK-NEXT: }
 
@@ -51,16 +51,16 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_1x4_1x3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
 // CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
 // CHECK-NEXT: }
 
@@ -73,16 +73,16 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_4x1_3x1
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
 // CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
 // CHECK-NEXT: }
 
@@ -95,48 +95,75 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
 
 // CHECK-LABEL: func.func @conv2d_aligned
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
 // CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
 // CHECK-NEXT: }
 
 // -----
 
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
   return %0 : tensor<2x9x9x2xf32>
 }
 
 // CHECK-LABEL: func.func @conv2d_unaligned
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT:  %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT:  %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT:  %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[ARG3]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:  ^bb0
+// CHECK-NEXT:    tensor.yield %[[CST]] : f32
+// CHECK-NEXT:  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT:  %[[PADDED_1:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:  ^bb0
+// CHECK-NEXT:    tensor.yield %[[CST]] : f32
+// CHECK-NEXT:  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
 
 // -----
 
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf16>, tensor<2x3x3x5xf16>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_type_promotion
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:   return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
 func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
   return %0 : tensor<2x4x4x2xf32>
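The `conv2d_unaligned` CHECK lines above pad an 11x11 input to 14x14 and a 9x9 output to 12x12 (`high[0, 3, 3, 0]`), then `tensor.extract_slice` the 9x9 result back out. The arithmetic can be sketched per spatial dimension as below. This is a minimal illustration, not code from the patch: `AlignedDim` and `alignDim` are hypothetical names, and it assumes stride 1, dilation 1, and that the tile count is rounded up (ceil of output extent over `m`), which is consistent with the CHECK lines and with `alignedInputH = tileH * heightM + (heightR - 1)` in `WinogradConv2D.cpp`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical per-dimension alignment for F(m, r) Winograd tiling.
struct AlignedDim {
  int64_t numTiles;       // ceil(outputExtent / m)
  int64_t alignedOutput;  // numTiles * m
  int64_t alignedInput;   // numTiles * m + (r - 1)
  int64_t inputHighPad;   // zeros appended to the end of the input
  int64_t outputHighPad;  // zeros appended to the end of the init tensor
};

AlignedDim alignDim(int64_t inputExtent, int64_t m, int64_t r) {
  assert(m > 0 && r > 0 && inputExtent >= r);
  // Assumes stride 1 and dilation 1, as in the tests.
  int64_t outputExtent = inputExtent - r + 1;
  int64_t numTiles = (outputExtent + m - 1) / m;  // round up
  int64_t alignedOutput = numTiles * m;
  int64_t alignedInput = alignedOutput + (r - 1);
  return {numTiles, alignedOutput, alignedInput, alignedInput - inputExtent,
          alignedOutput - outputExtent};
}
```

For the unaligned test (`inputExtent = 11`, `m = 4`, `r = 3`) this yields 3 tiles, an aligned output of 12, an aligned input of 14, and a high padding of 3 on both tensors; for the 6x6 input of `conv2d_type_promotion` both paddings are 0, matching the absence of `tensor.pad` in that test.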

>From f018ec0b6db2e67d41c4036779ead529cde6d5ff Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Sat, 29 Jun 2024 15:06:35 +0100
Subject: [PATCH 4/5] Add more tests in Linalg/roundtrip.mlir and
 Linalg/invalid.mlir

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp |  26 +++---
 mlir/test/Dialect/Linalg/invalid.mlir    | 103 +++++++++++++++++++++++
 mlir/test/Dialect/Linalg/roundtrip.mlir  |  21 +++++
 3 files changed, 137 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1283315f2eaef..1f50f2370facb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2746,11 +2746,11 @@ LogicalResult WinogradFilterTransformOp::verify() {
   int64_t r = getR();
 
   if (filterH != r && filterH != 1)
-    return failure();
+    return emitOpError("expect filter height either equals to r or 1");
   if (filterW != r && filterW != 1)
-    return failure();
+    return emitOpError("expect filter width either equals to r or 1");
   if (filterH == 1 && filterW == 1)
-    return failure();
+    return emitOpError("expect either filter height or width equals to r");
 
   return success();
 }
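The filter-transform verifier above accepts an `r x r` filter, or an `r x 1` / `1 x r` filter, presumably so that the transform can be applied along only one spatial axis, but rejects `1 x 1`. The sketch below restates those three conditions as a pure function so the `invalid.mlir` cases can be checked by hand; `FilterCheck` and `checkFilterShape` are hypothetical names, and the real verifier reports diagnostics through `emitOpError` instead of returning an enum.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical outcome of the filter shape checks, in the verifier's order.
enum class FilterCheck { Ok, BadHeight, BadWidth, BothUnit };

FilterCheck checkFilterShape(int64_t filterH, int64_t filterW, int64_t r) {
  assert(r > 0);
  // Each spatial extent must be either r (transformed) or 1 (untouched).
  if (filterH != r && filterH != 1)
    return FilterCheck::BadHeight;
  if (filterW != r && filterW != 1)
    return FilterCheck::BadWidth;
  // At least one axis must actually be transformed.
  if (filterH == 1 && filterW == 1)
    return FilterCheck::BothUnit;
  return FilterCheck::Ok;
}
```

With `r = 3`, the `invalid.mlir` filters `2x4x3x5`, `2x3x4x5` and `2x1x1x5` map to `BadHeight`, `BadWidth` and `BothUnit` respectively, while `2x3x3x5` is accepted.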
@@ -2781,21 +2781,21 @@ LogicalResult WinogradInputTransformOp::verify() {
   if (leftTransform) {
     int64_t tileH = (inputH - (r - 1)) / m;
     if (inputH != tileH * m + (r - 1))
-      return failure();
+      return emitOpError("input height cannot be tiled in full tile size");
     if (tileH != outputTileH)
-      return failure();
+      return emitOpError("number of output height tiles is not correct");
     if (outputH != m + r - 1)
-      return failure();
+      return emitOpError("expect output height equals to tile size");
   }
 
   if (rightTransform) {
     int64_t tileW = (inputW - (r - 1)) / m;
     if (inputW != tileW * m + (r - 1))
-      return failure();
+      return emitOpError("input width cannot be tiled in full tile size");
     if (tileW != outputTileW)
-      return failure();
+      return emitOpError("number of output width tiles is not correct");
     if (outputW != m + r - 1)
-      return failure();
+      return emitOpError("expect output width equals to tile size");
   }
 
   return success();
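The input-transform checks above rely on the fact that input tiles of `m + r - 1` elements overlap by `r - 1`, so a fully tiled extent has the form `tiles * m + (r - 1)`. A per-dimension restatement, applied to height or width when that axis is transformed, might look like the sketch below. `InputDimCheck` and `checkInputTransformDim` are hypothetical names; `tileSize` and `numTiles` stand for the matching extents of the `outs` tensor, whose layout in the tests is `[alphaH, alphaW, tileH, tileW, N, C]`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical outcome of the per-dimension input-transform checks.
enum class InputDimCheck { Ok, NotFullyTiled, WrongTileCount, WrongTileSize };

InputDimCheck checkInputTransformDim(int64_t inputExtent, int64_t tileSize,
                                     int64_t numTiles, int64_t m, int64_t r) {
  assert(m > 0 && r > 0);
  // Number of whole tiles; integer division discards any remainder.
  int64_t tiles = (inputExtent - (r - 1)) / m;
  if (inputExtent != tiles * m + (r - 1))
    return InputDimCheck::NotFullyTiled;
  if (tiles != numTiles)
    return InputDimCheck::WrongTileCount;
  if (tileSize != m + r - 1)
    return InputDimCheck::WrongTileSize;
  return InputDimCheck::Ok;
}
```

With `m = 4`, `r = 3`: an extent of 14 gives 3 tiles of size 6 and passes; 13 gives `(13 - 2) / 4 = 2` tiles, but `2 * 4 + 2 = 10 != 13`, which is the "cannot be tiled in full tile size" case in `invalid.mlir`.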
@@ -2826,16 +2826,16 @@ LogicalResult WinogradOutputTransformOp::verify() {
 
   if (leftTransform) {
     if (valueH != m + r - 1)
-      return failure();
+      return emitOpError("expect input height equals to input tile size");
     if (outputH != m * valueTileH)
-      return failure();
+      return emitOpError("expect output height aligned to output tile size");
   }
 
   if (rightTransform) {
     if (valueW != m + r - 1)
-      return failure();
+      return emitOpError("expect input width equals to input tile size");
     if (outputW != m * valueTileW)
-      return failure();
+      return emitOpError("expect output width aligned to output tile size");
   }
 
   return success();
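The output-transform checks are the mirror image: each incoming tile must have `m + r - 1` elements, and the output extent must be exactly `m` times the number of tiles (overlap is already removed by `A^T x y x A`). A hypothetical per-dimension restatement, with the `ins` layout `[alphaH, alphaW, tileH, tileW, N, F]` used in the tests:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical outcome of the per-dimension output-transform checks.
enum class OutputDimCheck { Ok, WrongTileSize, NotAligned };

OutputDimCheck checkOutputTransformDim(int64_t tileSize, int64_t numTiles,
                                       int64_t outputExtent, int64_t m,
                                       int64_t r) {
  assert(m > 0 && r > 0);
  // Each transformed tile must carry m + r - 1 elements along this axis.
  if (tileSize != m + r - 1)
    return OutputDimCheck::WrongTileSize;
  // Every tile produces exactly m outputs, with no partial tiles.
  if (outputExtent != m * numTiles)
    return OutputDimCheck::NotAligned;
  return OutputDimCheck::Ok;
}
```

With `m = 4`, `r = 3` and 3 tiles, a 12-element output passes, an 11-element output is the "aligned to output tile size" error, and a tile size of 5 is the "equals to input tile size" error exercised in `invalid.mlir`.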
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 44c81c31ace0f..e54060d7f8987 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -780,3 +780,106 @@ func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<
   return
 }
 
+// -----
+
+func.func @winograd_filter_transform_height(%arg0: tensor<2x4x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  // expected-error @+1 {{expect filter height either equals to r or 1}}
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x4x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+// -----
+
+func.func @winograd_filter_transform_width(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  // expected-error @+1 {{expect filter width either equals to r or 1}}
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x4x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+// -----
+
+func.func @winograd_filter_transform(%arg0: tensor<2x1x1x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  // expected-error @+1 {{expect either filter height or width equals to r}}
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x1x1x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
+  // expected-error @+1 {{input height cannot be tiled in full tile size}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x13x14x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  return %0 : tensor<6x6x3x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
+  // expected-error @+1 {{input width cannot be tiled in full tile size}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x13x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  return %0 : tensor<6x6x3x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32> {
+  // expected-error @+1 {{number of output height tiles is not correct}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32>
+  return %0 : tensor<6x6x2x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32> {
+  // expected-error @+1 {{number of output width tiles is not correct}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32>
+  return %0 : tensor<6x6x3x2x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32> {
+  // expected-error @+1 {{expect output height equals to tile size}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32>
+  return %0 : tensor<5x6x3x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_output_width(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32> {
+  // expected-error @+1 {{expect output width equals to tile size}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32>
+  return %0 : tensor<6x5x3x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_output_transform_input_height(%arg0: tensor<5x6x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
+  // expected-error @+1 {{expect input height equals to input tile size}}
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<5x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  return %0 : tensor<2x12x12x2xf32>
+}
+
+// -----
+
+func.func @winograd_output_transform_input_width(%arg0: tensor<6x5x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
+  // expected-error @+1 {{expect input width equals to input tile size}}
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x5x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  return %0 : tensor<2x12x12x2xf32>
+}
+
+// -----
+
+func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32> {
+  // expected-error @+1 {{expect output height aligned to output tile size}}
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32>
+  return %0 : tensor<2x11x12x2xf32>
+}
+
+// -----
+
+func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32> {
+  // expected-error @+1 {{expect output width aligned to output tile size}}
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
+  return %0 : tensor<2x12x11x2xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index b422066aade64..49fbe13405719 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -613,3 +613,24 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-SAME:     tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK:        return %[[D1]] : tensor<2x16x32xf32>
 // CHECK:      }
+
+// -----
+
+func.func @winograd(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %2 = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%2 : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+  %4 = tensor.empty() : tensor<36x2x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %6 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func @winograd
+// CHECK:         linalg.winograd_filter_transform m(4) r(3)
+// CHECK:         linalg.winograd_input_transform m(4) r(3)
+// CHECK:         linalg.winograd_output_transform m(4) r(3)
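Between the input/filter transforms and the output transform, the roundtrip test (like the `conv2d_type_promotion` CHECK lines) flattens the two `6x6` Winograd-domain axes into a batch of 36 and reduces over the channel dimension with `linalg.batch_matmul`. The shape bookkeeping can be sketched as follows; `collapsedShape` and `batchMatmulShape` are hypothetical helpers that only mirror the shapes in the test, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Result shape of a collapse: each reassociation group of source dims folds
// into one dim whose extent is the product of the group's extents.
Shape collapsedShape(const Shape &shape,
                     const std::vector<std::vector<size_t>> &groups) {
  Shape result;
  for (const auto &group : groups) {
    int64_t extent = 1;
    for (size_t dim : group) {
      assert(dim < shape.size());
      extent *= shape[dim];
    }
    result.push_back(extent);
  }
  return result;
}

// Result shape of a batched matmul: [B, M, K] x [B, K, N] -> [B, M, N].
Shape batchMatmulShape(const Shape &lhs, const Shape &rhs) {
  assert(lhs.size() == 3 && rhs.size() == 3);
  assert(lhs[0] == rhs[0] && lhs[2] == rhs[1]);
  return {lhs[0], lhs[1], rhs[2]};
}
```

Collapsing the filter `6x6x5x2` with `[[0, 1], [2], [3]]` gives `36x5x2`, collapsing the input `6x6x1x1x2x5` with `[[0, 1], [2, 3, 4], [5]]` gives `36x2x5`, and their batched product is `36x2x2`, which `tensor.expand_shape` turns back into `6x6x1x1x2x2`. In `conv2d_type_promotion` the same shapes appear with `f16` operands and an `f32` init tensor.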

>From 11a4ee23a67e72d01f0f3ec8bcf397064aeb0e61 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Sat, 29 Jun 2024 20:38:43 +0100
Subject: [PATCH 5/5] Address more comments

---
 .../Linalg/Transforms/WinogradConv2D.cpp      | 31 ++++++++-----------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 6b46f9e07abf8..351549bf2b434 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -110,24 +111,16 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
 
 /// Create an empty tensor with alignedType and insert the value into the
 /// created empty tensor with aligned size.
-static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
-                                   Value value,
-                                   ArrayRef<int64_t> alignedShape) {
-  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
+                                Value value, ArrayRef<int64_t> alignedShape) {
   auto valueType = cast<ShapedType>(value.getType());
   Type elementType = valueType.getElementType();
-  ArrayRef<int64_t> valueShape = valueType.getShape();
-  SmallVector<OpFoldResult, 6> lowIndices(alignedShape.size(), zeroIndex);
-  SmallVector<OpFoldResult, 6> highIndices;
-  for (unsigned i = 0; i < alignedShape.size(); ++i) {
-    highIndices.emplace_back(
-        rewriter.getIndexAttr(alignedShape[i] - valueShape[i]));
-  }
   auto alignedType = RankedTensorType::get(alignedShape, elementType);
-  Value pad_value = rewriter.create<arith::ConstantOp>(
+  Value padValue = rewriter.create<arith::ConstantOp>(
       loc, elementType, rewriter.getZeroAttr(elementType));
-  return rewriter.create<tensor::PadOp>(loc, alignedType, value, lowIndices,
-                                        highIndices, pad_value);
+
+  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
+                                       padValue, /*nofold=*/false);
 }
 
 /// Extract sub-tensor with extractedType from value.
@@ -165,6 +158,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   auto filterType = cast<ShapedType>(filter.getType());
   auto outputType = cast<ShapedType>(output.getType());
 
+  // TODO: Should we support dynamic shapes?
   if (!inputType.hasStaticShape())
     return rewriter.notifyMatchFailure(convOp,
                                        "expected a static shape for the input");
@@ -253,8 +247,8 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   int64_t alignedInputH = tileH * heightM + (heightR - 1);
   int64_t alignedInputW = tileW * widthM + (widthR - 1);
   if (alignedInputH != inputH || alignedInputW != inputW) {
-    input = insertToAlignedTensor(
-        rewriter, loc, input, {inputN, alignedInputH, alignedInputW, inputC});
+    input = padToAlignedTensor(rewriter, loc, input,
+                               {inputN, alignedInputH, alignedInputW, inputC});
   }
 
   retType = RankedTensorType::get(
@@ -279,8 +273,8 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   if (isOutputUnaligned) {
     auto alignedOutputType = RankedTensorType::get(
         {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
-    output = insertToAlignedTensor(rewriter, loc, output,
-                                   alignedOutputType.getShape());
+    output =
+        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
     outputType = alignedOutputType;
   }
 
@@ -327,6 +321,7 @@ class WinogradConv2DNhwcFhwc final
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
+  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
   patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
 }
 
