[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)

Hsiangkai Wang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jun 20 05:51:46 PDT 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96183

>From 24c4f957ae673c2955fc0674f91e488813d59350 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 17:39:49 +0100
Subject: [PATCH] [mlir][linalg] Decompose winograd operators

Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops with matrix multiplication
with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel sizes (3 and 5) and different output sizes
(2 and 4). Besides symmetric kernel sizes 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper "Fast Algorithms for
Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
---
 .../Dialect/Linalg/Transforms/Transforms.h    |   3 +
 .../Linalg/Transforms/WinogradConv2D.cpp      | 773 ++++++++++++++++++
 .../Linalg/winograd-conv2d-rewrite.mlir       | 105 +++
 .../Dialect/Linalg/TestLinalgTransforms.cpp   |  11 +
 4 files changed, 892 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index da107b66257a5..bb7ec590faad0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1703,6 +1703,9 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r);
 
+/// Patterns to decompose Winograd operators.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d1f4be8bbf29a..d245723c85646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -12,7 +12,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -23,6 +26,156 @@ namespace linalg {
 
 namespace {
 
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
+// m is the output dimension and r is the filter dimension, is
+//
+// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
+//
+// g is filter and d is input data. We need to prepare 6 constant
+// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
+//
+// The following tables define these constant transformation matrices for
+// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
+constexpr float G_2x2_3x3[] = {
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = {
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = {
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = {
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = {
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+constexpr float G_4x4_3x3[] = {
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+constexpr float G_2x2_5x5[] = {
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = {
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
 using TransformMapKeyTy = std::pair<int, int>;
 
 // We use F(m, r) to define the size of minimal filtering algorithms.
@@ -36,6 +189,92 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+struct TransformMatrix {
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+  const float *table;
+  int64_t rows;
+  int64_t cols;
+  int64_t scalarFactor;
+};
+
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> const_vec(transform.table, transform.rows * transform.cols);
+
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               const_vec));
+}
+
+Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  auto sourceShape = sourceType.getShape();
+  int64_t height = sourceShape[heightIdx];
+  int64_t width = sourceShape[widthIdx];
+
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> offsets(srcSize, zeroIndex);
+  offsets[outLoopIdx] = outLoopIndex;
+  offsets[inLoopIdx] = inLoopIndex;
+  SmallVector<OpFoldResult, 6> sizes(srcSize, oneIndex);
+  sizes[heightIdx] = rewriter.getIndexAttr(height);
+  sizes[widthIdx] = rewriter.getIndexAttr(width);
+  SmallVector<OpFoldResult, 6> strides(srcSize, oneIndex);
+  SmallVector<int64_t> targetShape(srcSize, 1);
+  targetShape[heightIdx] = height;
+  targetShape[widthIdx] = width;
+
+  auto targetType = RankedTensorType::get(targetShape, elementType);
+  auto extractFilterOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, targetType, source, offsets, sizes, strides);
+
+  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
+  auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, extractFilterOp, extractFilterType);
+
+  return extractFilter;
+}
+
+Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
+                   Value dest, Value outLoopIndex, Value inLoopIndex,
+                   int64_t height, int64_t width, int64_t outLoopIdx,
+                   int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                   int64_t destSize) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  SmallVector<int64_t> sliceShape(destSize, 1);
+  sliceShape[heightIdx] = height;
+  sliceShape[widthIdx] = width;
+  auto init = rewriter.create<tensor::EmptyOp>(loc, sliceShape, elementType);
+  auto result = tensor::createCanonicalRankReducingInsertSliceOp(rewriter, loc,
+                                                                 source, init);
+
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> retOffsets(destSize, zeroIndex);
+  retOffsets[outLoopIdx] = outLoopIndex;
+  retOffsets[inLoopIdx] = inLoopIndex;
+  SmallVector<OpFoldResult, 6> retSizes(destSize, oneIndex);
+  retSizes[heightIdx] = rewriter.getIndexAttr(height);
+  retSizes[widthIdx] = rewriter.getIndexAttr(width);
+  SmallVector<OpFoldResult, 6> strides(destSize, oneIndex);
+
+  auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+      loc, result, dest, retOffsets, retSizes, strides);
+
+  return insertSliceOp;
+}
+
 Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
   auto type = cast<ShapedType>(data.getType());
   auto elementType = type.getElementType();
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
                                                   reassociation);
 }
 
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is 2-dimensional. We need to extract H x W from
+// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+// After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract filter<h x w> from filter<f x h x w x c>
+//     %ret = linalg.matmul G, %extracted
+//     %ret = linalg.matmul %ret, GT
+//     %inserted = insert %ret into filter<tile_h x tile_w x h x w x c x f>
+//
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  if (filterH != r && filterH != 1)
+    return Value();
+  if (filterW != r && filterW != 1)
+    return Value();
+
+  // Return shape is <1 x 1 x H x W x C x F>
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value FIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (F, H, W, C)
+  auto extractFilter = extract2DData(
+      rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+  TransformMapKeyTy key = {m, r};
+  int64_t retRows = 1;
+  Value matmulRetValue = extractFilter;
+  if (leftTransform) {
+    // Get constant transform matrix G
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    const TransformMatrix &GMatrix = it->second;
+
+    retRows = GMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType);
+    // Multiply G x g
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix GT
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    const TransformMatrix &GTMatrix = it->second;
+
+    auto matmulType =
+        RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value GT = create2DTransformMatrix(rewriter, loc, GTMatrix, elementType);
+    // Multiply u = (G x g) x GT
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert (H, W) to (1, 1, H, W, C, F)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  int64_t retHeight = leftTransform ? m + r - 1 : 1;
+  int64_t retWidth = rightTransform ? m + r - 1 : 1;
+  auto insertSliceOp = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth,
+      /*outLoopIdx=*/5, /*inLoopIdx=*/4, /*heightIdx=*/2, /*widthIdx=*/3,
+      /*destSize=*/6);
+
+  rewriter.create<scf::YieldOp>(loc, insertSliceOp);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// This function transforms the input. The data layout of the input is NHWC.
+// The transformation matrix is 2-dimensional. We need to extract H x W from
+// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract input<h x w> from input<n x h x w x c>
+//     %ret = linalg.matmul BT, %extracted
+//     %ret = linalg.matmul %ret, B
+//     %inserted = insert %ret into input<tile_h x tile_w x h x w x n x c>
+//
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (inputH != alphaH && inputH != 1)
+    return Value();
+  if (inputW != alphaW && inputW != 1)
+    return Value();
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (N, H, W, C)
+  auto extractInput = extract2DData(
+      rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+      /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+  TransformMapKeyTy key = {m, r};
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  Value matmulRetValue = extractInput;
+  if (leftTransform) {
+    // Get constant transform matrix BT
+    auto it = BTMatrices.find(key);
+    if (it == BTMatrices.end())
+      return Value();
+    const TransformMatrix &BTMatrix = it->second;
+
+    retRows = BTMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value BT =
+        create2DTransformMatrix(rewriter, loc, BTMatrix, rewriter.getF32Type());
+    // Multiply BT x d
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix B
+    auto it = BMatrices.find(key);
+    if (it == BMatrices.end())
+      return Value();
+    const TransformMatrix &BMatrix = it->second;
+
+    retCols = BMatrix.cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+    Value B =
+        create2DTransformMatrix(rewriter, loc, BMatrix, rewriter.getF32Type());
+    // Multiply v = (BT x d) x B
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert v
+  // Insert (H, W) to (1, 1, H, W, N, C)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  auto combinedVal = insert2DData(
+      rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols,
+      /*outLoopIdx=*/4, /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3,
+      /*destSize=*/6);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
 // This function generates linalg.batch_matmul to multiply input with filter.
 // linalg.batch_matmul only supports 3-dimension data sets. We can treat
 // tileH x tileW x H x W data as the 1-dimension data array. That is to convert
@@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
+// This function transforms the output. The data layout of the output is HWNF.
+// The transformation matrix is 2-dimensional. We need to extract H x W from
+// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %f = lo_f to hi_f step 1
+//     %extracted = extract input<h x w> from result<h x w x n x f>
+//     %ret = linalg.matmul AT, %extracted
+//     %ret = linalg.matmul %ret, A
+//     %inserted = insert %ret into ret<n x h x w x f>
+//
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+                      Value output, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to AT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      ATMatrices = {
+          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+      };
+
+  // Map from (m, r) to A transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      AMatrices = {
+          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+      };
+
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F
+  int64_t valueH = valueShape[2];
+  int64_t valueW = valueShape[3];
+  int64_t valueN = valueShape[4];
+  int64_t valueF = valueShape[5];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (valueH != alphaH && valueH != 1)
+    return Value();
+  if (valueW != alphaW && valueW != 1)
+    return Value();
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value FIter = innerForBody->getArgument(0);
+
+  // Extract (H, W) from (1, 1, H, W, N, F)
+  auto extractValue = extract2DData(
+      rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4,
+      /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6);
+
+  TransformMapKeyTy key = {m, r};
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  int64_t leftScalarFactor = 1;
+  int64_t rightScalarFactor = 1;
+  Value matmulRetValue = extractValue;
+  if (leftTransform) {
+    // Get constant transform matrix AT
+    auto it = ATMatrices.find(key);
+    if (it == ATMatrices.end())
+      return Value();
+    const TransformMatrix &ATMatrix = it->second;
+
+    leftScalarFactor = ATMatrix.scalarFactor;
+    retRows = ATMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType);
+    // Multiply AT x m
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix A
+    auto it = AMatrices.find(key);
+    if (it == AMatrices.end())
+      return Value();
+    const TransformMatrix &AMatrix = it->second;
+
+    rightScalarFactor = AMatrix.scalarFactor;
+    auto matmulType =
+        RankedTensorType::get({retRows, AMatrix.cols}, elementType);
+    retCols = AMatrix.cols;
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType);
+    // Multiply y = (AT x m) x A
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Multiply scalar factor.
+  Value scalarFactor = rewriter.create<arith::ConstantOp>(
+      loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
+
+  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+  SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
+                                       identityAffineMap, identityAffineMap};
+  auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
+      loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
+      ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value scalarVal = args[0];
+        Value matrixVal = args[1];
+        Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
+                                                           matrixVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
+
+  // Insert slice y
+  // Insert (H, W) to (N, H, W, F)
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  Value combinedVal = insert2DData(rewriter, loc, scalarMatrixOp.getResult(0),
+                                   iterArg, NIter, FIter, retRows, retCols,
+                                   /*outLoopIdx=*/0,
+                                   /*inLoopIdx=*/3, /*heightIdx=*/1,
+                                   /*widthIdx=*/2, /*destSize=*/4);
+
+  rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
 Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
                             RankedTensorType alignedType) {
   Value alignedInput = rewriter.create<tensor::EmptyOp>(
@@ -289,6 +938,123 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   return transformedOutput.getDefiningOp();
 }
 
+FailureOr<Operation *>
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradFilterTransformOp op) {
+  Location loc = op.getLoc();
+  Value filter = op.getFilter();
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  Value transformedFilter =
+      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedFilter)
+    return failure();
+
+  rewriter.replaceOp(op, transformedFilter);
+
+  return transformedFilter.getDefiningOp();
+}
+
+FailureOr<Operation *>
+decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
+                                      linalg::WinogradInputTransformOp op) {
+  Location loc = op.getLoc();
+  Value input = op.getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = inputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = inputW != 1;
+  Value transformedInput =
+      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
+                     op.getR(), leftTransform, rightTransform);
+  if (!transformedInput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedInput);
+
+  return transformedInput.getDefiningOp();
+}
+
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  Location loc = op.getLoc();
+  Value value = op.getValue();
+  auto valueType = cast<ShapedType>(value.getType());
+  auto valueShape = valueType.getShape();
+  int64_t valueH = valueShape[2];
+  int64_t valueW = valueShape[3];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = valueH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = valueW != 1;
+  Value transformedOutput =
+      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedOutput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+class DecomposeWinogradFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradFilterTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradInputTransform final
+    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradOutputTransform final
+    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradOutputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
 class WinogradConv2DNhwcFhwc final
     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
 public:
@@ -323,5 +1089,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
   patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
 }
 
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
+  // Register the filter, input and output decompositions in one call.
+  patterns.insert<DecomposeWinogradFilterTransform,
+                  DecomposeWinogradInputTransform,
+                  DecomposeWinogradOutputTransform>(patterns.getContext());
+}
+
 } // end namespace linalg
 } // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
new file mode 100644
index 0000000000000..917d089c1981c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<2x4x4x2xf32> // Init tensor for the broadcast of %arg2.
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x4x4x2xf32> // %1: the single element of %arg2 copied to every output element.
+  %2 = tensor.empty() : tensor<1x1x6x6x5x2xf32> // Destination of the filter transform.
+  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
+  %4 = tensor.empty() : tensor<1x1x6x6x2x5xf32> // Destination of the input transform.
+  %5 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%4 : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
+  %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+  %6 = tensor.empty() : tensor<36x2x2xf32> // Destination of the batched matmul.
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<1x1x6x6x2x2xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %8 : tensor<2x4x4x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:   %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG5]], %[[ARG3]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S4]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:      %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:      %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:      %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:      %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:        %[[S17:.*]] = arith.mulf %[[IN]], %[[IN_9]] : f32
+// CHECK-NEXT:        linalg.yield %[[S17]] : f32
+// CHECK-NEXT:      } -> tensor<4x4xf32>
+// CHECK-NEXT:      %[[S16:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:      %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:      scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[S9]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 12cb46a5968f1..5899f56da7345 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -127,6 +127,9 @@ struct TestLinalgTransforms
       *this, "test-winograd-conv2d",
       llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
       llvm::cl::init(false)};
+  Option<bool> testDecomposeWinogradOps{
+      *this, "test-decompose-winograd-ops",
+      llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
 };
 } // namespace
 
@@ -218,6 +221,12 @@ static void applyWinogradConv2D(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
+  RewritePatternSet decomposePatterns(funcOp.getContext()); // Winograd only.
+  populateDecomposeWinogradOpsPatterns(decomposePatterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(decomposePatterns));
+} // The greedy driver's result is intentionally discarded above.
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -244,6 +253,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnnecessaryInputs(getOperation());
   if (testWinogradConv2D)
     return applyWinogradConv2D(getOperation());
+  if (testDecomposeWinogradOps)
+    return applyDecomposeWinogradOps(getOperation());
 }
 
 namespace mlir {



More information about the llvm-branch-commits mailing list