[Mlir-commits] [mlir] [mlir][linalg] Implement Winograd Conv2D. (PR #94470)

Hsiangkai Wang llvmlistbot at llvm.org
Thu Jun 6 14:47:04 PDT 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/94470

>From e11fa7a96f1eb794acd0a0ffdf7ec8191bba8fd3 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 17 May 2024 14:55:15 +0100
Subject: [PATCH] [mlir][linalg] Implement Winograd Conv2D.

This patch implements the Winograd Conv2D algorithm. It supports several
configurations of Winograd Conv2D, including F(2, 3), F(4, 3), and
F(2, 5). These configurations show that the implementation can support
different kernel sizes (3 and 5) and different output sizes (2 and 4).
Besides the symmetric kernel sizes 3x3 and 5x5, this patch also supports
1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper "Fast Algorithms for
Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
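
The rewrite can also be driven from the transform dialect via the new
structured.winograd_conv2d op added in this patch. A minimal sketch of a
transform script (the named-sequence boilerplate and the handle names are
illustrative, not part of this patch):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(
        %arg0: !transform.any_op {transform.readonly}) {
      // Match the convolution and rewrite it with the Winograd algorithm.
      %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]}
          in %arg0 : (!transform.any_op) -> !transform.any_op
      %transformed = transform.structured.winograd_conv2d %conv
          : (!transform.any_op) -> !transform.any_op
      transform.yield
    }
  }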
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  38 +
 .../Dialect/Linalg/Transforms/Transforms.h    |   8 +
 .../TransformOps/LinalgTransformOps.cpp       |  25 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 853 ++++++++++++++++++
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 570 ++++++++++++
 .../Dialect/Linalg/winograd-tiled-conv.mlir   | 116 +++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  12 +
 8 files changed, 1623 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir
 create mode 100644 mlir/test/Dialect/Linalg/winograd-tiled-conv.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..9917f8dbd19b8 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,42 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Rewrite the target Conv2D operation using the Winograd Conv2D algorithm.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle to the sequence of operations that replaces
+    the original convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..33a9b2aae0762 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,11 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to a sequence of operations implementing
+/// the Winograd Conv2D algorithm.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
@@ -1692,6 +1697,9 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply the Winograd Conv2D algorithm.
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..caca623edad83 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return winogradConv2D(rewriter, op);
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(
+                op, "only linalg.conv_2d_nhwc_fhwc is supported");
+          });
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..bdaef82318af1
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,853 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// Implement Winograd Conv2D algorithm. The implementation is based on the
+/// paper: Fast Algorithms for Convolutional Neural Networks
+/// (https://arxiv.org/abs/1509.09308)
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. The formula of the minimal 2D filtering algorithm
+// F(m x m, r x r), where m is the output dimension and r is the filter
+// dimension, is
+//
+// Y = A^T x [ (G x g x G^T) (*) (B^T x d x B) ] x A
+//
+// where x denotes matrix multiplication, (*) denotes element-wise
+// multiplication, g is the filter, and d is the input data. We need to
+// prepare 6 constant transformation matrices, G, G^T, B^T, B, A^T, and A
+// for this formula.
+//
+// The following tables define these constant transformation matrices for
+// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
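+//
+// As a concrete dimension check for F(2 x 2, 3 x 3) (using the matrices
+// below): the filter g is 3x3 and the input tile d is 4x4, since
+// alpha = m + r - 1 = 4. Then
+//
+//   G x g x G^T       : (4x3) x (3x3) x (3x4) -> 4x4
+//   B^T x d x B       : (4x4) x (4x4) x (4x4) -> 4x4
+//   element-wise (*)  : (4x4) (*) (4x4)       -> 4x4
+//   A^T x [...] x A   : (2x4) x (4x4) x (4x2) -> 2x2 output tile Y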
+constexpr float G_2x2_3x3[] = {
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = {
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = {
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = {
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = {
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+constexpr float G_4x4_3x3[] = {
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+constexpr float G_2x2_5x5[] = {
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = {
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+// We use F(m, r) to denote the size of the minimal filtering algorithm,
+// where m is the output dimension and r is the filter dimension. The
+// input dimension, alpha, is given by the formula alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+struct TransformMatrix {
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+  const float *table;
+  int64_t rows;
+  int64_t cols;
+  int64_t scalarFactor;
+};
+
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> const_vec(transform.table, transform.rows * transform.cols);
+
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               const_vec));
+}
+
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrices are 2-dimensional, so we extract the H x W
+// plane from FHWC and generate 2 levels of loops to iterate on F and C.
+// After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract filter<h x w> from filter<f x h x w x c>
+//     %ret = linalg.matmul G, %extracted
+//     %ret = linalg.matmul %ret, GT
+//     %inserted = insert %ret into filter<h x w x c x f>
+//
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      int64_t outputH, int64_t outputW,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  int64_t alphaH = outputH + filterH - 1;
+  int64_t alphaW = outputW + filterW - 1;
+
+  // Return shape is <H x W x C x F>
+  auto retType =
+      RankedTensorType::get({alphaH, alphaW, filterC, filterF}, elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value FIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (1, H, W, 1) from (F, H, W, C)
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets = {FIter, zeroIndex, zeroIndex, CIter};
+  SmallVector<OpFoldResult, 4> sizes = {oneIndex,                       // F
+                                        rewriter.getIndexAttr(filterH), // H
+                                        rewriter.getIndexAttr(filterW), // W
+                                        oneIndex};                      // C
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto targetType =
+      RankedTensorType::get({1, filterH, filterW, 1}, elementType);
+  auto extractFilterOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, targetType, filter, offsets, sizes, strides);
+
+  // Extract (H, W) from (1, H, W, 1)
+  // g = extracted (H, W)
+  auto extractFilterType =
+      RankedTensorType::get({filterH, filterW}, elementType);
+  auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, extractFilterOp, extractFilterType);
+
+  TransformMapKeyTy key = {leftTransform ? outputH : outputW,
+                           leftTransform ? filterH : filterW};
+  int64_t retRows = 1;
+  Value matmulRetValue = extractFilter;
+  if (leftTransform) {
+    // Get constant transform matrix G
+    auto it = GMatrices.find(key);
+    assert(it != GMatrices.end());
+    const TransformMatrix &GMatrix = it->second;
+
+    retRows = GMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType);
+    // Multiply G x g
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix GT
+    auto it = GTMatrices.find(key);
+    assert(it != GTMatrices.end());
+    const TransformMatrix &GTMatrix = it->second;
+
+    auto matmulType =
+        RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value GT = create2DTransformMatrix(rewriter, loc, GTMatrix, elementType);
+    // Multiply u = (G x g) x GT
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert u
+  // Insert (H, W) to (H, W, 1, 1)
+  auto sliceType = RankedTensorType::get({alphaH, alphaW, 1, 1}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, sliceType.getShape(), elementType);
+  auto result = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, matmulRetValue, init);
+
+  // Insert (H, W, 1, 1) to (H, W, C, F)
+  SmallVector<OpFoldResult, 4> retOffsets = {zeroIndex, zeroIndex, CIter,
+                                             FIter};
+  SmallVector<OpFoldResult, 4> retSizes = {rewriter.getIndexAttr(alphaH),
+                                           rewriter.getIndexAttr(alphaW),
+                                           oneIndex, oneIndex};
+
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+      loc, result, iterArg, retOffsets, retSizes, strides);
+
+  rewriter.create<scf::YieldOp>(loc, insertSliceOp.getResult());
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// This function transforms the input. The data layout of the input is NHWC.
+// The transformation matrices are 2-dimensional, so we extract the H x W
+// plane from NHWC and generate 2 levels of loops to iterate on N and C.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract input<h x w> from input<n x h x w x c>
+//     %ret = linalg.matmul BT, %extracted
+//     %ret = linalg.matmul %ret, B
+//     %inserted = insert %ret into input<h x w x n x c>
+//
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     int64_t outputH, int64_t outputW,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+
+  auto retType =
+      RankedTensorType::get({inputH, inputW, inputN, inputC}, elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (1, H, W, 1) from (N, H, W, C)
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets = {NIter, zeroIndex, zeroIndex, CIter};
+  SmallVector<OpFoldResult, 4> sizes = {oneIndex,                      // N
+                                        rewriter.getIndexAttr(inputH), // H
+                                        rewriter.getIndexAttr(inputW), // W
+                                        oneIndex};                     // C
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto targetType = RankedTensorType::get({1, inputH, inputW, 1}, elementType);
+  auto extractInputOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, targetType, input, offsets, sizes, strides);
+
+  // Extract (H, W) from (1, H, W, 1)
+  // d = extracted (H, W)
+  auto extractInputType = RankedTensorType::get({inputH, inputW}, elementType);
+  auto extractInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, extractInputOp, extractInputType);
+
+  TransformMapKeyTy key = {leftTransform ? outputH : outputW,
+                           leftTransform ? inputH - outputH + 1
+                                         : inputW - outputW + 1};
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  Value matmulRetValue = extractInput;
+  if (leftTransform) {
+    // Get constant transform matrix BT
+    auto it = BTMatrices.find(key);
+    assert(it != BTMatrices.end());
+    const TransformMatrix &BTMatrix = it->second;
+
+    retRows = BTMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value BT =
+        create2DTransformMatrix(rewriter, loc, BTMatrix, rewriter.getF32Type());
+    // Multiply BT x d
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix B
+    auto it = BMatrices.find(key);
+    assert(it != BMatrices.end());
+    const TransformMatrix &BMatrix = it->second;
+
+    retCols = BMatrix.cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+    Value B =
+        create2DTransformMatrix(rewriter, loc, BMatrix, rewriter.getF32Type());
+    // Multiply v = (BT x d) x B
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert v
+  // Insert (H, W) to (H, W, 1, 1)
+  auto sliceType = RankedTensorType::get({retRows, retCols, 1, 1}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, sliceType.getShape(), elementType);
+  auto result = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, matmulRetValue, init);
+
+  // Insert (H, W, 1, 1) to (H, W, N, C)
+  SmallVector<OpFoldResult, 4> retOffsets = {zeroIndex, zeroIndex, NIter,
+                                             CIter};
+  SmallVector<OpFoldResult, 4> retSizes = {rewriter.getIndexAttr(inputH), // H
+                                           rewriter.getIndexAttr(inputW), // W
+                                           oneIndex,                      // N
+                                           oneIndex};                     // C
+
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  Value combinedVal =
+      rewriter
+          .create<tensor::InsertSliceOp>(loc, result, iterArg, retOffsets,
+                                         retSizes, strides)
+          .getResult();
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// This function generates linalg.batch_matmul to multiply the input with the
+// filter. linalg.batch_matmul only supports 3-dimensional operands, so we
+// treat the H x W data as a single flattened dimension, i.e. we convert
+// [H, W, N, C] to [H x W, N, C]. In this way, the 4-dimensional data becomes
+// a 3-dimensional representation that is suitable for linalg.batch_matmul.
+//
+// The batched matmul performs the matrix multiplication with the reduction
+// on the channel dimension.
+//
+// We get
+//
+// %collapsed_input = tensor.collapse_shape %input
+// %collapsed_filter = tensor.collapse_shape %filter
+// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+// %expanded_ret = tensor.expand_shape %ret
+//
+// After this function, the return value has data layout (H, W, N, F).
+//
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                     Value transformedFilter, Value transformedInput) {
+  // Collapse transformedFilter
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  auto filterElemType = filterType.getElementType();
+  auto filterShape = filterType.getShape();
+  auto collapseFilterType = RankedTensorType::get(
+      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
+      filterElemType);
+  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}, {3}};
+  auto collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, collapseFilterType, transformedFilter, reassociation);
+  // Collapse transformedInput
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  auto inputElemType = inputType.getElementType();
+  auto inputShape = inputType.getShape();
+  auto collapseInputType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1], inputShape[2], inputShape[3]},
+      inputElemType);
+  auto collapseInput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, collapseInputType, transformedInput, reassociation);
+  // Batched matrix multiply
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1], inputShape[2], filterShape[3]},
+      inputElemType);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                inputElemType);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // Expand matmul result
+  auto expandType = RankedTensorType::get(
+      {inputShape[0], inputShape[1], inputShape[2], filterShape[3]},
+      inputElemType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, expandType, matmulOp.getResult(0), reassociation);
+  return expandOutput;
+}
+
+// This function transforms the intermediate result back to the output. The
+// data layout of the intermediate result is HWNF. The transformation
+// matrices are 2-dimensional, so we extract the H x W plane from HWNF and
+// generate 2 levels of loops to iterate on N and F.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %f = lo_f to hi_f step 1
+//     %extracted = extract value<h x w> from result<h x w x n x f>
+//     %ret = linalg.matmul AT, %extracted
+//     %ret = linalg.matmul %ret, A
+//     %inserted = insert %ret into ret<n x h x w x f>
+//
+Value outputTransform(RewriterBase &rewriter, Location loc, Value output,
+                      Value value, int64_t outputH, int64_t outputW,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to AT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      ATMatrices = {
+          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+      };
+
+  // Map from (m, r) to A transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      AMatrices = {
+          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+      };
+
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto valueShape = valueType.getShape(); // H, W, N, F
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueN = valueShape[2];
+  int64_t valueF = valueShape[3];
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value FIter = innerForBody->getArgument(0);
+
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets = {zeroIndex, zeroIndex, NIter, FIter};
+  SmallVector<OpFoldResult, 4> sizes = {rewriter.getIndexAttr(valueH), // alpha
+                                        rewriter.getIndexAttr(valueW), // alpha
+                                        oneIndex,                      // N
+                                        oneIndex};                     // F
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  // Extract (H, W, 1, 1) from (H, W, N, F)
+  auto targetType = RankedTensorType::get({valueH, valueW, 1, 1}, elementType);
+  auto extractValueOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, targetType, value, offsets, sizes, strides);
+
+  // Extract (H, W) from (H, W, 1, 1)
+  // m = extracted (H, W)
+  auto extractValueType = RankedTensorType::get({valueH, valueW}, elementType);
+  auto extractValue = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, extractValueOp, extractValueType);
+
+  TransformMapKeyTy key = {leftTransform ? outputH : outputW,
+                           leftTransform ? valueH - outputH + 1
+                                         : valueW - outputW + 1};
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  int64_t leftScalarFactor = 1;
+  int64_t rightScalarFactor = 1;
+  Value matmulRetValue = extractValue;
+  if (leftTransform) {
+    // Get constant transform matrix AT
+    auto it = ATMatrices.find(key);
+    assert(it != ATMatrices.end());
+    const TransformMatrix &ATMatrix = it->second;
+
+    leftScalarFactor = ATMatrix.scalarFactor;
+    retRows = ATMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType);
+    // Multiply AT x m
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix A
+    auto it = AMatrices.find(key);
+    assert(it != AMatrices.end());
+    const TransformMatrix &AMatrix = it->second;
+
+    rightScalarFactor = AMatrix.scalarFactor;
+    auto matmulType =
+        RankedTensorType::get({retRows, AMatrix.cols}, elementType);
+    retCols = AMatrix.cols;
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType);
+    // Multiply y = (AT x m) x A
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Multiply scalar factor.
+  Value scalarFactor = rewriter.create<arith::ConstantOp>(
+      loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
+
+  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+  SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
+                                       identityAffineMap, identityAffineMap};
+  auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
+      loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
+      ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value scalarVal = args[0];
+        Value matrixVal = args[1];
+        Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
+                                                           matrixVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
+
+  // Insert slice y
+  // Insert (H, W) to (1, H, W, 1)
+  auto sliceType = RankedTensorType::get({1, retRows, retCols, 1}, elementType);
+  init =
+      rewriter.create<tensor::EmptyOp>(loc, sliceType.getShape(), elementType);
+  auto result = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, scalarMatrixOp.getResult(0), init);
+
+  auto outputType = cast<ShapedType>(output.getType());
+  auto outputShape = outputType.getShape();
+
+  // Insert (1, H, W, 1) to (N, H, W, F)
+  SmallVector<OpFoldResult, 4> retOffsets = {NIter, zeroIndex, zeroIndex,
+                                             FIter};
+  SmallVector<OpFoldResult, 4> retSizes = {
+      oneIndex, rewriter.getIndexAttr(outputShape[1]),
+      rewriter.getIndexAttr(outputShape[2]), oneIndex};
+
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  Value combinedVal =
+      rewriter
+          .create<tensor::InsertSliceOp>(loc, result, iterArg, retOffsets,
+                                         retSizes, strides)
+          .getResult();
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
+                                            linalg::Conv2DNhwcFhwcOp convOp) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  auto outputType = cast<ShapedType>(output.getType());
+  int64_t outputH = outputType.getShape()[1];
+  int64_t outputW = outputType.getShape()[2];
+  auto filterType = cast<ShapedType>(filter.getType());
+  int64_t filterH = filterType.getShape()[1];
+  int64_t filterW = filterType.getShape()[2];
+  auto inputType = cast<ShapedType>(input.getType());
+  int64_t inputH = inputType.getShape()[1];
+  int64_t inputW = inputType.getShape()[2];
+
+  // Check that the input, output, and filter sizes satisfy the relationship
+  // input = output + filter - 1.
+  if (inputH != outputH + filterH - 1)
+    return failure();
+  if (inputW != outputW + filterW - 1)
+    return failure();
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1), or F(1 x m, 1 x r).
+  if ((outputH != outputW) && (outputH != 1 && outputW != 1))
+    return failure();
+  if ((filterH != filterW) && (filterH != 1 && filterW != 1))
+    return failure();
+
+  if ((outputH == 1 && filterH != 1) || (outputH != 1 && filterH == 1))
+    return failure();
+  if ((outputW == 1 && filterW != 1) || (outputW != 1 && filterW == 1))
+    return failure();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = outputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = outputW != 1;
+
+  // The set of supported (m, r) configurations.
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {leftTransform ? outputH : outputW,
+                           leftTransform ? filterH : filterW};
+  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (it == validConfigs.end())
+    return failure();
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+  Value transformedFilter = filterTransform(
+      rewriter, loc, filter, outputH, outputW, leftTransform, rightTransform);
+  Value transformedInput = inputTransform(
+      rewriter, loc, input, outputH, outputW, leftTransform, rightTransform);
+  Value matmulRet =
+      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
+  Value transformedOutput =
+      outputTransform(rewriter, loc, output, matmulRet, outputH, outputW,
+                      leftTransform, rightTransform);
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(winogradConv2DHelper(rewriter, convOp)))
+      return failure();
+
+    return success();
+  }
+};
+} // end anonymous namespace
+
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op) {
+  return winogradConv2DHelper(rewriter, op);
+}
+
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<WinogradConv2DNhwcFhwc>(context);
+}
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..79c4ec957bd2b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,570 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %2 : tensor<2x4x4x2xf32>
+}
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK:      func.func @conv2d_4x4_3x3
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x6x6x5xf32>, %[[ARG1:.+]]: tensor<2x3x3x5xf32>, %[[ARG2:.+]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:     %[[CST:.+]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:     %[[CST_0:.+]] = arith.constant
+// CHECK-SAME:    [1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01]
+// CHECK-SAME:    [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01]
+// CHECK-SAME:    [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00]
+// CHECK-SAME:    [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]
+// CHECK-DAG:     %[[CST_1:.+]] = arith.constant
+// CHECK-SAME:    [1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]
+// CHECK-DAG:     %[[CST_2:.+]] = arith.constant
+// CHECK-SAME:    [2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01]
+// CHECK-SAME:    [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01]
+// CHECK-SAME:    [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]
+// CHECK-DAG:     %[[CST_3:.+]] = arith.constant
+// CHECK-SAME:    [2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]
+// CHECK-DAG:     %[[CST_4:.+]] = arith.constant
+// CHECK-SAME:    [1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]
+// CHECK-DAG:     %[[CST_5:.+]] = arith.constant
+// CHECK-SAME:    [1.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [-0.333333343, 0.333333343, -0.333333343]
+// CHECK-SAME:    [-0.333333343, -0.333333343, -0.333333343]
+// CHECK-SAME:    [0.0833333358, -0.166666672, 0.333333343]
+// CHECK-SAME:    [0.0833333358, 0.166666672, 0.333333343]
+// CHECK-SAME:    [0.000000e+00, 0.000000e+00, 1.000000e+00]
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[S0:.+]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:    %[[S1:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:    ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:      linalg.yield %[[IN]] : f32
+// CHECK-NEXT:    } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:    %[[S2:.+]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:    %[[S3:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S2]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:      %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:        %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:        %[[S10:.+]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:        %[[S11:.+]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:        %[[S12:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:        %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:        %[[S14:.+]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:        scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[S4:.+]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:    %[[S5:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S4]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:      %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:        %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:        %[[S10:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:        %[[S11:.+]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:        %[[S12:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:        %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:        %[[S14:.+]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x2x5xf32>
+// CHECK-NEXT:        scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      scf.yield %[[S9]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[COLLAPSED:.+]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:    %[[COLLAPSED_6:.+]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:    %[[S6:.+]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:    %[[S7:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:    %[[EXPANDED:.+]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// CHECK-NEXT:    %[[S8:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:      %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x2x2xf32> to tensor<6x6x1x1xf32>
+// CHECK-NEXT:        %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:        %[[S10:.+]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:        %[[S11:.+]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:        %[[S12:.+]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:        %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:        %[[S14:.+]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:        %[[S15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:        ^bb0(%[[IN:.+]]: f32, %[[IN_9:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:          %[[S17:.+]] = arith.mulf %[[IN]], %[[IN_9]] : f32
+// CHECK-NEXT:          linalg.yield %[[S17]] : f32
+// CHECK-NEXT:        } -> tensor<4x4xf32>
+// CHECK-NEXT:        %[[S16:.+]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:        scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      scf.yield %[[S9]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  }
+
+// -----
+
+func.func @conv2d_2x2_3x3(%arg0: tensor<2x4x4x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x2x2x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x4x4x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %2 : tensor<2x2x2x2xf32>
+}
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK:      func.func @conv2d_2x2_3x3
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x4x4x5xf32>, %[[ARG1:.+]]: tensor<2x3x3x5xf32>, %[[ARG2:.+]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK-DAG:     %[[CST:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:     %[[CST_0:.+]] = arith.constant
+// CHECK-SAME:    [1.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [1.000000e+00, -1.000000e+00]
+// CHECK-SAME:    [1.000000e+00, 1.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 1.000000e+00]
+// CHECK-DAG:     %[[CST_1:.+]] = arith.constant
+// CHECK-SAME:    [1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -1.000000e+00, 1.000000e+00, 1.000000e+00]
+// CHECK-DAG:     %[[CST_2:.+]] = arith.constant
+// CHECK-SAME:    [-1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -1.000000e+00, 1.000000e+00, -1.000000e+00]
+// CHECK-SAME:    [1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00]
+// CHECK-DAG:     %[[CST_3:.+]] = arith.constant
+// CHECK-SAME:    [-1.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -1.000000e+00, 1.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -1.000000e+00, 0.000000e+00, 1.000000e+00]
+// CHECK-DAG:     %[[CST_4:.+]] = arith.constant
+// CHECK-SAME:    [-1.000000e+00, 5.000000e-01, 5.000000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, -5.000000e-01, 5.000000e-01, 0.000000e+00]
+// CHECK-SAME:    [0.000000e+00, 5.000000e-01, 5.000000e-01, 1.000000e+00]
+// CHECK-DAG:     %[[CST_5:.+]] = arith.constant
+// CHECK-SAME:    [-1.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:    [5.000000e-01, -5.000000e-01, 5.000000e-01]
+// CHECK-SAME:    [5.000000e-01, 5.000000e-01, 5.000000e-01]
+// CHECK-SAME:    [0.000000e+00, 0.000000e+00, 1.000000e+00]
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT:    %[[S0:.+]] = tensor.empty() : tensor<2x2x2x2xf32>
+// CHECK-NEXT:    %[[S1:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:    ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:      linalg.yield %[[IN]] : f32
+// CHECK-NEXT:    } -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:    %[[S2:.+]] = tensor.empty() : tensor<4x4x5x2xf32>
+// CHECK-NEXT:    %[[S3:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S2]]) -> (tensor<4x4x5x2xf32>) {
+// CHECK-NEXT:      %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<4x4x5x2xf32>) {
+// CHECK-NEXT:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:        %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:        %[[S10:.+]] = tensor.empty() : tensor<4x3xf32>
+// CHECK-NEXT:        %[[S11:.+]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<4x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<4x3xf32>) -> tensor<4x3xf32>
+// CHECK-NEXT:        %[[S12:.+]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:        %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<4x3xf32>, tensor<3x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:        %[[S14:.+]] = tensor.empty() : tensor<4x4x1x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [4, 4, 1, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<4x4x1x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [4, 4, 1, 1] [1, 1, 1, 1] : tensor<4x4x1x1xf32> into tensor<4x4x5x2xf32>
+// CHECK-NEXT:        scf.yield %[[INSERTED_SLICE_8]] : tensor<4x4x5x2xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      scf.yield %[[S9]] : tensor<4x4x5x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[S4:.+]] = tensor.empty() : tensor<4x4x2x5xf32>
+// CHECK-NEXT:    %[[S5:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S4]]) -> (tensor<4x4x2x5xf32>) {
+// CHECK-NEXT:      %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<4x4x2x5xf32>) {
+// CHECK-NEXT:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x4x4x5xf32> to tensor<1x4x4x1xf32>
+// CHECK-NEXT:        %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> to tensor<4x4xf32>
+// CHECK-NEXT:        %[[S10:.+]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:        %[[S11:.+]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S10]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:        %[[S12:.+]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:        %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:        %[[S14:.+]] = tensor.empty() : tensor<4x4x1x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [4, 4, 1, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<4x4x1x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]]] [4, 4, 1, 1] [1, 1, 1, 1] : tensor<4x4x1x1xf32> into tensor<4x4x2x5xf32>
+// CHECK-NEXT:        scf.yield %[[INSERTED_SLICE_8]] : tensor<4x4x2x5xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      scf.yield %[[S9]] : tensor<4x4x2x5xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[COLLAPSED:.+]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<4x4x5x2xf32> into tensor<16x5x2xf32>
+// CHECK-NEXT:    %[[COLLAPSED_6:.+]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<4x4x2x5xf32> into tensor<16x2x5xf32>
+// CHECK-NEXT:    %[[S6:.+]] = tensor.empty() : tensor<16x2x2xf32>
+// CHECK-NEXT:    %[[S7:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<16x2x5xf32>, tensor<16x5x2xf32>) outs(%[[S6]] : tensor<16x2x2xf32>) -> tensor<16x2x2xf32>
+// CHECK-NEXT:    %[[EXPANDED:.+]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [4, 4, 2, 2] : tensor<16x2x2xf32> into tensor<4x4x2x2xf32>
+// CHECK-NEXT:    %[[S8:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S1]]) -> (tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:      %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]]] [4, 4, 1, 1] [1, 1, 1, 1] : tensor<4x4x2x2xf32> to tensor<4x4x1x1xf32>
+// CHECK-NEXT:        %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [4, 4, 1, 1] [1, 1, 1, 1] : tensor<4x4x1x1xf32> to tensor<4x4xf32>
+// CHECK-NEXT:        %[[S10:.+]] = tensor.empty() : tensor<2x4xf32>
+// CHECK-NEXT:        %[[S11:.+]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<2x4xf32>, tensor<4x4xf32>) outs(%[[S10]] : tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK-NEXT:        %[[S12:.+]] = tensor.empty() : tensor<2x2xf32>
+// CHECK-NEXT:        %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<2x4xf32>, tensor<4x2xf32>) outs(%[[S12]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+// CHECK-NEXT:        %[[S14:.+]] = tensor.empty() : tensor<2x2xf32>
+// CHECK-NEXT:        %[[S15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<2x2xf32>) outs(%[[S14]] : tensor<2x2xf32>) {
+// CHECK-NEXT:        ^bb0(%[[IN:.+]]: f32, %[[IN_9:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:          %[[S17:.+]] = arith.mulf %[[IN]], %[[IN_9]] : f32
+// CHECK-NEXT:          linalg.yield %[[S17]] : f32
+// CHECK-NEXT:        } -> tensor<2x2xf32>
+// CHECK-NEXT:        %[[S16:.+]] = tensor.empty() : tensor<1x2x2x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 2, 2, 1] [1, 1, 1, 1] : tensor<2x2xf32> into tensor<1x2x2x1xf32>
+// CHECK-NEXT:        %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 2, 2, 1] [1, 1, 1, 1] : tensor<1x2x2x1xf32> into tensor<2x2x2x2xf32>
+// CHECK-NEXT:        scf.yield %[[INSERTED_SLICE_8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      scf.yield %[[S9]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT:  }
+
+// -----
+
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x2x2x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %2 : tensor<2x2x2x2xf32>
+}
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK:          func.func @conv2d_2x2_5x5
+// CHECK-SAME:     (%[[ARG0:.+]]: tensor<2x6x6x5xf32>, %[[ARG1:.+]]: tensor<2x5x5x5xf32>, %[[ARG2:.+]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK-DAG:      %[[CST:.+]] = arith.constant 2.560000e+02 : f32
+// CHECK-DAG:      %[[CST_0:.+]] = arith.constant
+// CHECK-SAME:     [5.000000e-01, 0.000000e+00]
+// CHECK-SAME:     [1.000000e+00, -1.000000e+00]
+// CHECK-SAME:     [1.000000e+00, 1.000000e+00]
+// CHECK-SAME:     [2.000000e+00, -1.000000e+00]
+// CHECK-SAME:     [1.000000e+00, 2.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 5.000000e-01]
+// CHECK-DAG:      %[[CST_1:.+]] = arith.constant
+// CHECK-SAME:     [5.000000e-01, 1.000000e+00, 1.000000e+00, 2.000000e+00, 1.000000e+00, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -1.000000e+00, 1.000000e+00, -1.000000e+00, 2.000000e+00, 5.000000e-01]
+// CHECK-DAG:      %[[CST_2:.+]] = arith.constant
+// CHECK-SAME:     [1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:     [1.875000e-01, 1.250000e-01, -1.250000e-01, 2.500000e-01, -1.250000e-01, 1.250000e-01]
+// CHECK-SAME:     [-2.500000e-01, 6.250000e-02, -3.125000e-01, -1.250000e-01, -2.500000e-01, 1.875000e-01]
+// CHECK-SAME:     [-1.875000e-01, -3.125000e-01, -6.250000e-02, -2.500000e-01, 1.250000e-01, -2.500000e-01]
+// CHECK-SAME:     [1.250000e-01, 1.250000e-01, 1.250000e-01, 1.250000e-01, 2.500000e-01, -1.875000e-01]
+// CHECK-SAME:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.250000e-01]
+// CHECK-DAG:      %[[CST_3:.+]] = arith.constant
+// CHECK-SAME:     [1.250000e-01, 1.875000e-01, -2.500000e-01, -1.875000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 1.250000e-01, 6.250000e-02, -3.125000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -1.250000e-01, -3.125000e-01, -6.250000e-02, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -1.250000e-01, -2.500000e-01, 1.250000e-01, 2.500000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 1.250000e-01, 1.875000e-01, -2.500000e-01, -1.875000e-01, 1.250000e-01]
+// CHECK-DAG:      %[[CST_4:.+]] = arith.constant
+// CHECK-SAME:     [1.000000e+00, 0.166666672, -0.166666672, -0.266666681, 0.0166666675, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -0.166666672, -0.166666672, 0.13333334, 0.0333333351, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 0.166666672, -0.166666672, -0.0666666701, 0.0666666701, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -0.166666672, -0.166666672, 0.0333333351, 0.13333334, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 0.166666672, -0.166666672, -0.0166666675, 0.266666681, 1.000000e+00]
+// CHECK-DAG:      %[[CST_5:.+]] = arith.constant
+// CHECK-SAME:     [1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:     [0.166666672, -0.166666672, 0.166666672, -0.166666672, 0.166666672]
+// CHECK-SAME:     [-0.166666672, -0.166666672, -0.166666672, -0.166666672, -0.166666672]
+// CHECK-SAME:     [-0.266666681, 0.13333334, -0.0666666701, 0.0333333351, -0.0166666675]
+// CHECK-SAME:     [0.0166666675, 0.0333333351, 0.0666666701, 0.13333334, 0.266666681]
+// CHECK-SAME:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00]
+// CHECK-DAG:      %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:      %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG:      %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:      %[[C0:.+]] = arith.constant 0 : index
+// CHECK:          %[[S0:.+]] = tensor.empty() : tensor<2x2x2x2xf32>
+// CHECK-NEXT:     %[[S1:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:     ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:       linalg.yield %[[IN]] : f32
+// CHECK-NEXT:     } -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:     %[[S2:.+]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:     %[[S3:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S2]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 5, 5, 1] [1, 1, 1, 1] : tensor<2x5x5x5xf32> to tensor<1x5x5x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 5, 5, 1] [1, 1, 1, 1] : tensor<1x5x5x1xf32> to tensor<5x5xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<6x5xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x5xf32>, tensor<5x5xf32>) outs(%[[S10]] : tensor<6x5xf32>) -> tensor<6x5xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:         %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x5xf32>, tensor<5x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:         %[[S14:.+]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[S4:.+]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:     %[[S5:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S4]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:         %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:         %[[S14:.+]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x2x5xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_8]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[COLLAPSED:.+]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:     %[[COLLAPSED_6:.+]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:     %[[S6:.+]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:     %[[S7:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:     %[[EXPANDED:.+]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// CHECK-NEXT:     %[[S8:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S1]]) -> (tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<2x2x2x2xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x2x2xf32> to tensor<6x6x1x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_7:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<2x6xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<2x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<2x6xf32>) -> tensor<2x6xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<2x2xf32>
+// CHECK-NEXT:         %[[S13:.+]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<2x6xf32>, tensor<6x2xf32>) outs(%[[S12]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+// CHECK-NEXT:         %[[S14:.+]] = tensor.empty() : tensor<2x2xf32>
+// CHECK-NEXT:         %[[S15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<2x2xf32>) outs(%[[S14]] : tensor<2x2xf32>) {
+// CHECK-NEXT:         ^bb0(%[[IN:.+]]: f32, %[[IN_9:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:           %[[S17:.+]] = arith.mulf %[[IN]], %[[IN_9]] : f32
+// CHECK-NEXT:           linalg.yield %[[S17]] : f32
+// CHECK-NEXT:         } -> tensor<2x2xf32>
+// CHECK-NEXT:         %[[S16:.+]] = tensor.empty() : tensor<1x2x2x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 2, 2, 1] [1, 1, 1, 1] : tensor<2x2xf32> into tensor<1x2x2x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_8:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 2, 2, 1] [1, 1, 1, 1] : tensor<1x2x2x1xf32> into tensor<2x2x2x2xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT:   }
+
+// -----
+
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x1x4x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x1x4x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+  return %2 : tensor<2x1x4x2xf32>
+}
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK:      func.func @conv2d_1x4_1x3
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x1x6x5xf32>, %[[ARG1:.+]]: tensor<2x1x3x5xf32>, %[[ARG2:.+]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
+// CHECK-DAG:      %[[CST:.+]] = arith.constant 3.200000e+01 : f32
+// CHECK-DAG:      %[[CST_0:.+]] = arith.constant
+// CHECK-SAME:     [1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:     [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01]
+// CHECK-SAME:     [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01]
+// CHECK-SAME:     [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00]
+// CHECK-SAME:     [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]
+// CHECK-DAG:      %[[CST_1:.+]] = arith.constant
+// CHECK-SAME:     [2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01]
+// CHECK-SAME:     [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01]
+// CHECK-SAME:     [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]
+// CHECK-DAG:      %[[CST_2:.+]] = arith.constant
+// CHECK-SAME:     [1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]
+// CHECK-DAG:      %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:      %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG:      %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:      %[[C0:.+]] = arith.constant 0 : index
+// CHECK:          %[[S0:.+]] = tensor.empty() : tensor<2x1x4x2xf32>
+// CHECK-NEXT:     %[[S1:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
+// CHECK-NEXT:     ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:       linalg.yield %[[IN]] : f32
+// CHECK-NEXT:     } -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:     %[[S2:.+]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT:     %[[S3:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S2]]) -> (tensor<1x6x5x2xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<1x6x5x2xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 1, 3, 1] [1, 1, 1, 1] : tensor<2x1x3x5xf32> to tensor<1x1x3x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 1, 3, 1] [1, 1, 1, 1] : tensor<1x1x3x1xf32> to tensor<1x3xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<1x6xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE_4]], %[[CST_2]] : tensor<1x3xf32>, tensor<3x6xf32>) outs(%[[S10]] : tensor<1x6xf32>) -> tensor<1x6xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<1x6x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S11]] into %[[S12]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6xf32> into tensor<1x6x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> into tensor<1x6x5x2xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_5]] : tensor<1x6x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<1x6x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[S4:.+]] = tensor.empty() : tensor<1x6x2x5xf32>
+// CHECK-NEXT:     %[[S5:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S4]]) -> (tensor<1x6x2x5xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<1x6x2x5xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 1, 6, 1] [1, 1, 1, 1] : tensor<2x1x6x5xf32> to tensor<1x1x6x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 1, 6, 1] [1, 1, 1, 1] : tensor<1x1x6x1xf32> to tensor<1x6xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<1x6xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE_4]], %[[CST_1]] : tensor<1x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<1x6xf32>) -> tensor<1x6xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<1x6x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S11]] into %[[S12]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6xf32> into tensor<1x6x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> into tensor<1x6x2x5xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_5]] : tensor<1x6x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<1x6x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[COLLAPSED:.+]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:     %[[COLLAPSED_3:.+]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:     %[[S6:.+]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:     %[[S7:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_3]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:     %[[EXPANDED:.+]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x2x2xf32>
+// CHECK-NEXT:     %[[S8:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S1]]) -> (tensor<2x1x4x2xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<2x1x4x2xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x2x2xf32> to tensor<1x6x1x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<1x6xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<1x4xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE_4]], %[[CST_0]] : tensor<1x6xf32>, tensor<6x4xf32>) outs(%[[S10]] : tensor<1x4xf32>) -> tensor<1x4xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<1x4xf32>
+// CHECK-NEXT:         %[[S13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S11]] : f32, tensor<1x4xf32>) outs(%[[S12]] : tensor<1x4xf32>) {
+// CHECK-NEXT:         ^bb0(%[[IN:.+]]: f32, %[[IN_6:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:           %[[S15:.+]] = arith.mulf %[[IN]], %[[IN_6]] : f32
+// CHECK-NEXT:           linalg.yield %[[S15]] : f32
+// CHECK-NEXT:         } -> tensor<1x4xf32>
+// CHECK-NEXT:         %[[S14:.+]] = tensor.empty() : tensor<1x1x4x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [1, 1, 4, 1] [1, 1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 1, 4, 1] [1, 1, 1, 1] : tensor<1x1x4x1xf32> into tensor<2x1x4x2xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_5]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT:   }
+
+// -----
+
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x1x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x1x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %2 : tensor<2x4x1x2xf32>
+}
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK:      func.func @conv2d_4x1_3x1
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x6x1x5xf32>, %[[ARG1:.+]]: tensor<2x3x1x5xf32>, %[[ARG2:.+]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK-DAG:      %[[CST:.+]] = arith.constant 3.200000e+01 : f32
+// CHECK-DAG:      %[[CST_0:.+]] = arith.constant
+// CHECK-SAME:     [1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]
+// CHECK-DAG:      %[[CST_1:.+]] = arith.constant
+// CHECK-SAME:     [2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00]
+// CHECK-SAME:     [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]
+// CHECK-DAG:      %[[CST_2:.+]] = arith.constant
+// CHECK-SAME:     [1.000000e+00, 0.000000e+00, 0.000000e+00]
+// CHECK-SAME:     [-0.333333343, 0.333333343, -0.333333343]
+// CHECK-SAME:     [-0.333333343, -0.333333343, -0.333333343]
+// CHECK-SAME:     [0.0833333358, -0.166666672, 0.333333343]
+// CHECK-SAME:     [0.0833333358, 0.166666672, 0.333333343]
+// CHECK-SAME:     [0.000000e+00, 0.000000e+00, 1.000000e+00]
+// CHECK-DAG:      %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:      %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG:      %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:      %[[C0:.+]] = arith.constant 0 : index
+// CHECK:          %[[S0:.+]] = tensor.empty() : tensor<2x4x1x2xf32>
+// CHECK-NEXT:     %[[S1:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
+// CHECK-NEXT:     ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:       linalg.yield %[[IN]] : f32
+// CHECK-NEXT:     } -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:     %[[S2:.+]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT:     %[[S3:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S2]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<6x1xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE_4]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S10]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<6x1x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S11]] into %[[S12]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<6x1x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[S4:.+]] = tensor.empty() : tensor<6x1x2x5xf32>
+// CHECK-NEXT:     %[[S5:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S4]]) -> (tensor<6x1x2x5xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<6x1x2x5xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<6x1xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_4]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S10]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<6x1x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S11]] into %[[S12]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x2x5xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<6x1x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[COLLAPSED:.+]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:     %[[COLLAPSED_3:.+]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:     %[[S6:.+]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:     %[[S7:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_3]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:     %[[EXPANDED:.+]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x2x2xf32>
+// CHECK-NEXT:     %[[S8:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.+]] = %[[S1]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK-NEXT:       %[[S9:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK-NEXT:         %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x2x2xf32> to tensor<6x1x1x1xf32>
+// CHECK-NEXT:         %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> to tensor<6x1xf32>
+// CHECK-NEXT:         %[[S10:.+]] = tensor.empty() : tensor<4x1xf32>
+// CHECK-NEXT:         %[[S11:.+]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE_4]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK-NEXT:         %[[S12:.+]] = tensor.empty() : tensor<4x1xf32>
+// CHECK-NEXT:         %[[S13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S11]] : f32, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) {
+// CHECK-NEXT:         ^bb0(%[[IN:.+]]: f32, %[[IN_6:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:           %[[S15:.+]] = arith.mulf %[[IN]], %[[IN_6]] : f32
+// CHECK-NEXT:           linalg.yield %[[S15]] : f32
+// CHECK-NEXT:         } -> tensor<4x1xf32>
+// CHECK-NEXT:         %[[S14:.+]] = tensor.empty() : tensor<1x4x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
+// CHECK-NEXT:         %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
+// CHECK-NEXT:         scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/winograd-tiled-conv.mlir b/mlir/test/Dialect/Linalg/winograd-tiled-conv.mlir
new file mode 100644
index 0000000000000..df7126f7a94f5
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-tiled-conv.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:2 = transform.structured.tile_using_for %0 tile_sizes [0, 4, 4, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %2 = transform.structured.winograd_conv2d %1 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<2x10x10x5xf32>,
+// CHECK-SAME:  %[[ARG1:.+]]: tensor<2x3x3x5xf32>,
+// CHECK-SAME:  %[[ARG2:.+]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-NEXT:  %[[CST:.+]] = arith.constant 1.024000e+03 : f32
+// CHECK-NEXT:  %[[CST_0:.+]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-NEXT:  %[[CST_1:.+]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-NEXT:  %[[CST_2:.+]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-NEXT:  %[[CST_3:.+]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-NEXT:  %[[CST_4:.+]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-NEXT:  %[[CST_5:.+]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-NEXT:  %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT:  %[[C5:.+]] = arith.constant 5 : index
+// CHECK-NEXT:  %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NEXT:  %[[C4:.+]] = arith.constant 4 : index
+// CHECK-NEXT:  %[[C8:.+]] = arith.constant 8 : index
+// CHECK-NEXT:  %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT:  %[[S0:.+]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S1:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S2:.+]] = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG4:.+]] = %[[S1]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:    %[[S3:.+]] = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[ARG3]], %[[ARG5]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_6:.+]] = tensor.extract_slice %[[ARG6]][0, %[[ARG3]], %[[ARG5]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK-NEXT:      %[[S4:.+]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:      %[[S5:.+]] = scf.for %[[ARG7:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.+]] = %[[S4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:        %[[S11:.+]] = scf.for %[[ARG9:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:          %[[EXTRACTED_SLICE_8:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:          %[[EXTRACTED_SLICE_9:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:          %[[S12:.+]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:          %[[S13:.+]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S12]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:          %[[S14:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:          %[[S15:.+]] = linalg.matmul ins(%[[S13]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:          %[[S16:.+]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:          %[[INSERTED_SLICE_10:.+]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:          %[[INSERTED_SLICE_11:.+]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][0, 0, %[[ARG9]], %[[ARG7]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:          scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:        }
+// CHECK-NEXT:        scf.yield %[[S11]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[S6:.+]] = tensor.empty() : tensor<6x6x2x5xf32>
+// CHECK-NEXT:      %[[S7:.+]] = scf.for %[[ARG7:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.+]] = %[[S6]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:        %[[S11:.+]] = scf.for %[[ARG9:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<6x6x2x5xf32>) {
+// CHECK-NEXT:          %[[EXTRACTED_SLICE_8:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:          %[[EXTRACTED_SLICE_9:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:          %[[S12:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:          %[[S13:.+]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:          %[[S14:.+]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:          %[[S15:.+]] = linalg.matmul ins(%[[S13]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:          %[[S16:.+]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK-NEXT:          %[[INSERTED_SLICE_10:.+]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK-NEXT:          %[[INSERTED_SLICE_11:.+]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x2x5xf32>
+// CHECK-NEXT:          scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:        }
+// CHECK-NEXT:        scf.yield %[[S11]] : tensor<6x6x2x5xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[COLLAPSED:.+]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:      %[[COLLAPSED_7:.+]] = tensor.collapse_shape %[[S7]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:      %[[S8:.+]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:      %[[S9:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_7]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S8]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:      %[[EXPANDED:.+]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2], [3]] output_shape [6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x2x2xf32>
+// CHECK-NEXT:      %[[S10:.+]] = scf.for %[[ARG7:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.+]] = %[[EXTRACTED_SLICE_6]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:        %[[S11:.+]] = scf.for %[[ARG9:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:          %[[EXTRACTED_SLICE_8:.+]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x2x2xf32> to tensor<6x6x1x1xf32>
+// CHECK-NEXT:          %[[EXTRACTED_SLICE_9:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:          %[[S12:.+]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:          %[[S13:.+]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:          %[[S14:.+]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:          %[[S15:.+]] = linalg.matmul ins(%[[S13]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:          %[[S16:.+]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:          %[[S17:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S15]] : f32, tensor<4x4xf32>) outs(%[[S16]] : tensor<4x4xf32>) {
+// CHECK-NEXT:          ^bb0(%[[IN:.+]]: f32, %[[IN_12:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:            %[[S19:.+]] = arith.mulf %[[IN]], %[[IN_12]] : f32
+// CHECK-NEXT:            linalg.yield %[[S19]] : f32
+// CHECK-NEXT:          } -> tensor<4x4xf32>
+// CHECK-NEXT:          %[[S18:.+]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:          %[[INSERTED_SLICE_10:.+]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:          %[[INSERTED_SLICE_11:.+]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:          scf.yield %[[INSERTED_SLICE_11]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:        }
+// CHECK-NEXT:        scf.yield %[[S11]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, %[[ARG3]], %[[ARG5]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S3]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[S2]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..eba60560553a6 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +211,12 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateWinogradConv2DPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +241,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {
