[Mlir-commits] [mlir] 27ee33d - [mlir][linalg] Decompose winograd operators (#96183)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 17 22:04:57 PDT 2024


Author: Hsiangkai Wang
Date: 2024-07-18T06:04:53+01:00
New Revision: 27ee33d1368b9772f75285932c00479a0fae82ee

URL: https://github.com/llvm/llvm-project/commit/27ee33d1368b9772f75285932c00479a0fae82ee
DIFF: https://github.com/llvm/llvm-project/commit/27ee33d1368b9772f75285932c00479a0fae82ee.diff

LOG: [mlir][linalg] Decompose winograd operators (#96183)

Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops with matrix multiplication
with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel sizes (3 and 5) and different output sizes
(2 and 4). Besides symmetric kernel sizes 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper "Fast Algorithms for
Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).

Reviewers: ftynse, Max191, GeorgeARM, nicolasvasilache, MaheshRavishankar, dcaballe, rengolin

Reviewed By: ftynse, Max191

Pull Request: https://github.com/llvm/llvm-project/pull/96183

Added: 
    mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index eac6eb4387a0f..0c7a8edff222f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1746,6 +1746,9 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r);
 
+/// Patterns to decompose Winograd operators.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
+
 /// Adds patterns that reduce the rank of named contraction ops that have
 /// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
 /// `<corresponding linalg named op>`, `expand_shape` (if on tensors).  For example a

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 18dd4769f9a49..754f832e98eea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -12,10 +12,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/MathExtras.h"
 
 namespace mlir {
@@ -23,6 +27,156 @@ namespace linalg {
 
 namespace {
 
+// clang-format off
+/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+/// result. For the minimal 2D filtering algorithm F(m x m, r x r), where m is
+/// the output dimension and r is the filter dimension, the formula is
+///
+/// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
+///
+/// g is filter and d is input data. We need to prepare 6 constant
+/// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
+///
+/// The following tables define these constant transformation matrices for
+/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
+constexpr float G_2x2_3x3[] = {
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = {
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = {
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = {
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = {
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+constexpr float G_4x4_3x3[] = {
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+constexpr float G_2x2_5x5[] = {
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = {
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
 using TransformMapKeyTy = std::pair<int, int>;
 
 /// We use F(m, r) to define the size of minimal filtering algorithms.
@@ -36,6 +190,408 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+/// Structure to keep information of constant transform matrices.
+struct TransformMatrix {
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+  const float *table;
+  int64_t rows;
+  int64_t cols;
+  int64_t scalarFactor;
+};
+
+/// Utility function to convert constant array to arith.constant Value.
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+
+  return builder.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               constVec));
+}
+
+/// Extract height x width data from 4D tensors.
+Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
+                          Value loopNorFIndex, Value loopCorFIndex,
+                          Value heightOffset, Value widthOffset,
+                          int64_t extractHeight, int64_t extractWidth,
+                          int64_t loopNorFIdx, int64_t loopCorFIdx,
+                          int64_t heightIdx, int64_t widthIdx) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  int64_t srcSize = sourceType.getRank();
+
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets;
+  offsets.resize(srcSize);
+  offsets[loopNorFIdx] = loopNorFIndex;
+  offsets[loopCorFIdx] = loopCorFIndex;
+  offsets[heightIdx] = heightOffset;
+  offsets[widthIdx] = widthOffset;
+  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
+  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
+  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
+  SmallVector<OpFoldResult> strides(srcSize, oneIndex);
+
+  auto extractFilterType =
+      RankedTensorType::get({extractHeight, extractWidth}, elementType);
+  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
+      loc, extractFilterType, source, offsets, sizes, strides);
+
+  return extractFilterOp;
+}
+
+/// Extract height x width data from 6D tensors.
+Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
+                          Value tileHIndex, Value tileWIndex,
+                          Value loopNorFIndex, Value loopCorFIndex,
+                          int64_t tileHIdx, int64_t tileWIdx,
+                          int64_t loopNorFIdx, int64_t loopCorFIdx,
+                          int64_t heightIdx, int64_t widthIdx) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  auto sourceShape = sourceType.getShape();
+  int64_t srcSize = sourceType.getRank();
+  int64_t height = sourceShape[heightIdx];
+  int64_t width = sourceShape[widthIdx];
+
+  auto zeroIndex = builder.getIndexAttr(0);
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
+  offsets.resize(srcSize);
+  offsets[tileHIdx] = tileHIndex;
+  offsets[tileWIdx] = tileWIndex;
+  offsets[loopNorFIdx] = loopNorFIndex;
+  offsets[loopCorFIdx] = loopCorFIndex;
+  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
+  sizes[heightIdx] = builder.getIndexAttr(height);
+  sizes[widthIdx] = builder.getIndexAttr(width);
+  SmallVector<OpFoldResult> strides(srcSize, oneIndex);
+
+  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
+  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
+      loc, extractFilterType, source, offsets, sizes, strides);
+
+  return extractFilterOp;
+}
+
+/// Insert transformed height x width data into the 4D tensor `dest` at the
+/// given loop indices and height/width offsets.
+Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
+                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
+                       Value heightOffset, Value widthOffset, int64_t height,
+                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
+                       int64_t heightIdx, int64_t widthIdx) {
+  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> retOffsets;
+  retOffsets.resize(destSize);
+  retOffsets[loopNorFIdx] = loopNorFIndex;
+  retOffsets[loopCorFIdx] = loopCorFIndex;
+  retOffsets[heightIdx] = heightOffset;
+  retOffsets[widthIdx] = widthOffset;
+  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
+  retSizes[heightIdx] = builder.getIndexAttr(height);
+  retSizes[widthIdx] = builder.getIndexAttr(width);
+  SmallVector<OpFoldResult> strides(destSize, oneIndex);
+
+  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+      loc, source, dest, retOffsets, retSizes, strides);
+
+  return insertSliceOp;
+}
+
+/// Insert transformed height x width data into the 6D tensor `dest` at the
+/// given tile and loop indices.
+Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
+                       Value dest, Value tileHIndex, Value tileWIndex,
+                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
+                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
+                       int64_t loopNorFIdx, int64_t loopCorFIdx,
+                       int64_t heightIdx, int64_t widthIdx) {
+  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
+  auto zeroIndex = builder.getIndexAttr(0);
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
+  retOffsets.resize(destSize);
+  retOffsets[tileHIdx] = tileHIndex;
+  retOffsets[tileWIdx] = tileWIndex;
+  retOffsets[loopNorFIdx] = loopNorFIndex;
+  retOffsets[loopCorFIdx] = loopCorFIndex;
+  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
+  retSizes[heightIdx] = builder.getIndexAttr(height);
+  retSizes[widthIdx] = builder.getIndexAttr(width);
+  SmallVector<OpFoldResult> strides(destSize, oneIndex);
+
+  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+      loc, source, dest, retOffsets, retSizes, strides);
+
+  return insertSliceOp;
+}
+
+/// This function transforms the filter. The data layout of the filter is FHWC.
+/// The transformation matrix is 2-dimensional. We need to extract H x W from
+/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+/// After the transformation, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  if (filterH != r && filterH != 1)
+    return Value();
+  if (filterW != r && filterW != 1)
+    return Value();
+
+  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value FIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (F, H, W, C).
+    auto extractFilter =
+        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
+                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
+                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+    TransformMapKeyTy key = {m, r};
+    int64_t retRows = 1;
+    Value matmulRetValue = extractFilter;
+    if (leftTransform) {
+      // Get constant transform matrix G.
+      auto it = GMatrices.find(key);
+      if (it == GMatrices.end())
+        return {};
+      const TransformMatrix &GMatrix = it->second;
+
+      retRows = GMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
+      // Multiply G x g.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Get constant transform matrix GT.
+      auto it = GTMatrices.find(key);
+      if (it == GTMatrices.end())
+        return {};
+      const TransformMatrix &GTMatrix = it->second;
+
+      auto matmulType =
+          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
+      // Multiply u = (G x g) x GT.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, C, F).
+    int64_t retHeight = leftTransform ? m + r - 1 : 1;
+    int64_t retWidth = rightTransform ? m + r - 1 : 1;
+
+    auto insertSliceOp =
+        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
+                         zeroIdx, zeroIdx, retHeight, retWidth,
+                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
+                         /*heightIdx=*/0, /*widthIdx=*/1);
+
+    return {insertSliceOp};
+  };
+
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
+      {oneStep, oneStep}, {retValue}, buildBody);
+  return loops.results[0];
+}
+
+/// This function transforms the input. The data layout of the input is NHWC.
+/// The transformation matrix is 2-dimensional. We need to extract alphaH x
+/// alphaW tiles from NHWC first and generate 4 levels of loops to iterate on
+/// tileH, tileW, N, and C. After the transformation, we get
+///
+/// scf.for %h = 0 to tileH step 1
+///   scf.for %w = 0 to tileW step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %c = 0 to C step 1
+///         %extracted = extract %extracted<alphaH x alphaW> from
+///                              %input<N x H x W x C>
+///                              at [%n, (%h x m), (%w x m), %c]
+///         %ret = linalg.matmul BT, %extracted
+///         %ret = linalg.matmul %ret, B
+///         %inserted = insert %ret<alphaH x alphaW> into
+///                            %output<alphaH x alphaW x tileH x tileW x N x C>
+///                            at [0, 0, %h, %w, %n, %c]
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  auto valueType = cast<ShapedType>(retValue.getType());
+  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
+  int64_t tileH = valueShape[2];
+  int64_t tileW = valueShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if ((inputH != (tileH * m) + (r - 1)) && inputH != 1)
+    return Value();
+  if ((inputW != (tileW * m) + (r - 1)) && inputW != 1)
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value tileHIter = ivs[0];
+    Value tileWIter = ivs[1];
+    Value NIter = ivs[2];
+    Value CIter = ivs[3];
+
+    auto context = builder.getContext();
+    auto affineMap =
+        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+    Value heightOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
+    Value widthOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+
+    // Extract (H, W) from (N, H, W, C).
+    auto extractInput =
+        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
+                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
+                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+    TransformMapKeyTy key = {m, r};
+    int64_t retRows = 1;
+    int64_t retCols = 1;
+    Value matmulRetValue = extractInput;
+    if (leftTransform) {
+      // Get constant transform matrix BT.
+      auto it = BTMatrices.find(key);
+      if (it == BTMatrices.end())
+        return {};
+      const TransformMatrix &BTMatrix = it->second;
+
+      retRows = BTMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value BT =
+          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+      // Multiply BT x d.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Get constant transform matrix B.
+      auto it = BMatrices.find(key);
+      if (it == BMatrices.end())
+        return {};
+      const TransformMatrix &BMatrix = it->second;
+
+      retCols = BMatrix.cols;
+      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+      Value B =
+          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+      // Multiply v = (BT x d) x B.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, tileH, tileW, N, C).
+    auto combinedVal = insert2DDataTo6D(
+        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
+        CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
+        /*heightIdx=*/0, /*widthIdx=*/1);
+
+    return {combinedVal};
+  };
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
+  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
+      {tileHBound, tileWBound, nUpperBound, cUpperBound},
+      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
+  return loops.results[0];
+}
+
 /// This function generates linalg.batch_matmul to multiply input with filter.
 /// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
 /// tileH x tileW x H x W data as the 1-dimensional data array. That is to
@@ -107,6 +663,185 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
+/// This function transforms the output. The layout of the value to transform
+/// is (alphaH, alphaW, tileH, tileW, N, F). The transformation matrix is
+/// 2-dimensional. We extract alphaH x alphaW first and generate 4 levels of
+/// loops to iterate on tileH, tileW, N, and F. After that, we get
+///
+/// scf.for %h = 0 to tileH step 1
+///   scf.for %w = 0 to tileW step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %f = 0 to F step 1
+///         %extracted = extract %extracted<alphaH x alphaW> from
+///                              %input<alphaH x alphaW x tileH x tileW x N x F>
+///                              at [0, 0, %h, %w, %n, %f]
+///         %ret = linalg.matmul AT, %extracted
+///         %ret = linalg.matmul %ret, A
+///         %inserted = insert %ret<alphaH x alphaW> into
+///                            output<N x H x W x F>
+///                            at [%n, (%h x m), (%w x m), %f]
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+                      Value output, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to AT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      ATMatrices = {
+          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+      };
+
+  // Map from (m, r) to A transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      AMatrices = {
+          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+      };
+
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueN = valueShape[4];
+  int64_t valueF = valueShape[5];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (valueH != alphaH && valueH != 1)
+    return Value();
+  if (valueW != alphaW && valueW != 1)
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value tileHIter = ivs[0];
+    Value tileWIter = ivs[1];
+    Value NIter = ivs[2];
+    Value FIter = ivs[3];
+
+    // Extract (H, W) from (H, W, tileH, tileW, N, F).
+    auto extractValue =
+        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
+                            FIter, 2, 3, /*loopNorFIdx=*/4,
+                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
+
+    TransformMapKeyTy key = {m, r};
+    int64_t retRows = 1;
+    int64_t retCols = 1;
+    int64_t leftScalarFactor = 1;
+    int64_t rightScalarFactor = 1;
+    Value matmulRetValue = extractValue;
+    if (leftTransform) {
+      // Get constant transform matrix AT.
+      auto it = ATMatrices.find(key);
+      if (it == ATMatrices.end())
+        return {};
+      const TransformMatrix &ATMatrix = it->second;
+
+      leftScalarFactor = ATMatrix.scalarFactor;
+      retRows = ATMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
+      // Multiply AT x m.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Get constant transform matrix A.
+      auto it = AMatrices.find(key);
+      if (it == AMatrices.end())
+        return {};
+      const TransformMatrix &AMatrix = it->second;
+
+      rightScalarFactor = AMatrix.scalarFactor;
+      auto matmulType =
+          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
+      retCols = AMatrix.cols;
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
+      // Multiply y = (AT x m) x A.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (leftScalarFactor * rightScalarFactor != 1) {
+      // Multiply scalar factor.
+      Value scalarFactor = builder.create<arith::ConstantOp>(
+          loc,
+          FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+      SmallVector<AffineMap> affineMaps = {
+          AffineMap::get(2, 0, init.getContext()), identityAffineMap};
+      auto broadcastedScalar =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, matmulType, ValueRange{scalarFactor}, ValueRange{init},
+                  affineMaps,
+                  llvm::ArrayRef<utils::IteratorType>{
+                      utils::IteratorType::parallel,
+                      utils::IteratorType::parallel},
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+                  })
+              .getResult(0);
+
+      matmulRetValue = builder
+                           .create<linalg::MulOp>(
+                               loc, matmulType,
+                               ValueRange{broadcastedScalar, matmulRetValue},
+                               ValueRange{init})
+                           .getResult(0);
+    }
+
+    auto context = builder.getContext();
+    auto affineMap =
+        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+    Value heightOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
+    Value widthOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+
+    // Insert (H, W) to (N, H, W, F).
+    Value combinedVal =
+        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
+                         heightOffset, widthOffset, retRows, retCols,
+                         /*loopNorFIdx=*/0,
+                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
+                         /*widthIdx=*/2);
+
+    return {combinedVal};
+  };
+
+  int64_t tilwH = valueShape[2];
+  int64_t tileW = valueShape[3];
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH);
+  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
+      {tileHBound, tileWBound, nUpperBound, fUpperBound},
+      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
+  return loops.results[0];
+}
+
 /// Create an empty tensor with alignedType and insert the value into the
 /// created empty tensor with aligned size.
 static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
@@ -292,6 +1027,120 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   return transformedOutput.getDefiningOp();
 }
 
+/// Lowers a linalg.winograd_filter_transform op by calling filterTransform
+/// and substituting its result for the op; fails if that yields a null value.
+FailureOr<Operation *>
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradFilterTransformOp op) {
+  Location loc = op.getLoc();
+  Value filter = op.getFilter();
+  auto dims = cast<ShapedType>(filter.getType()).getShape();
+
+  // Dims 1 and 2 of the filter are read as its height and width. A unit
+  // height (F(1 x m, 1 x r)) needs no left-side transform and a unit width
+  // (F(m x 1, r x 1)) needs no right-side transform.
+  const bool needsLeft = dims[1] != 1;
+  const bool needsRight = dims[2] != 1;
+
+  Value result =
+      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
+                      op.getR(), needsLeft, needsRight);
+  if (!result)
+    return failure();
+
+  // Swap the op for the lowered value and hand back the op that produces it.
+  rewriter.replaceOp(op, result);
+  return result.getDefiningOp();
+}
+
+/// Lowers a linalg.winograd_input_transform op by calling inputTransform and
+/// substituting its result for the op; fails if that yields a null value.
+FailureOr<Operation *>
+decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
+                                      linalg::WinogradInputTransformOp op) {
+  Location loc = op.getLoc();
+  Value input = op.getInput();
+  auto inputShape = cast<ShapedType>(input.getType()).getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = inputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = inputW != 1;
+  // Pass the `input` value fetched above instead of re-querying the op.
+  Value transformedInput =
+      inputTransform(rewriter, loc, input, op.getOutput(), op.getM(),
+                     op.getR(), leftTransform, rightTransform);
+  if (!transformedInput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedInput);
+  return transformedInput.getDefiningOp();
+}
+
+/// Lowers a linalg.winograd_output_transform op by calling outputTransform
+/// and substituting its result for the op; fails if that yields a null value.
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  Value value = op.getValue();
+  auto dims = cast<ShapedType>(value.getType()).getShape();
+
+  // The two leading dims of the value operand are read as its height and
+  // width. A unit height (F(1 x m, 1 x r)) needs no left-side transform and
+  // a unit width (F(m x 1, r x 1)) needs no right-side transform.
+  const bool needsLeft = dims[0] != 1;
+  const bool needsRight = dims[1] != 1;
+
+  Value result =
+      outputTransform(rewriter, op.getLoc(), value, op.getOutput(),
+                      op.getM(), op.getR(), needsLeft, needsRight);
+  if (!result)
+    return failure();
+
+  // Swap the op for the lowered value and hand back the op that produces it.
+  rewriter.replaceOp(op, result);
+
+  return result.getDefiningOp();
+}
+
+/// Rewrite pattern that lowers linalg.winograd_filter_transform operations by
+/// delegating to decomposeWinogradFilterTransformHelper.
+struct DecomposeWinogradFilterTransform final
+    : OpRewritePattern<linalg::WinogradFilterTransformOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    return decomposeWinogradFilterTransformHelper(rewriter, op);
+  }
+};
+
+/// Rewrite pattern that lowers linalg.winograd_input_transform operations by
+/// delegating to decomposeWinogradInputTransformHelper.
+struct DecomposeWinogradInputTransform final
+    : OpRewritePattern<linalg::WinogradInputTransformOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    return decomposeWinogradInputTransformHelper(rewriter, op);
+  }
+};
+
+/// Rewrite pattern that lowers linalg.winograd_output_transform operations by
+/// delegating to decomposeWinogradOutputTransformHelper.
+struct DecomposeWinogradOutputTransform final
+    : OpRewritePattern<linalg::WinogradOutputTransformOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    return decomposeWinogradOutputTransformHelper(rewriter, op);
+  }
+};
+
 /// A rewrite pattern for Winograd Conv2D algorithm.
 class WinogradConv2DNhwcFhwc final
     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
@@ -328,5 +1177,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
   patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
 }
 
+/// Registers the Winograd filter/input/output decomposition patterns.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<DecomposeWinogradFilterTransform,
+                  DecomposeWinogradInputTransform,
+                  DecomposeWinogradOutputTransform>(patterns.getContext());
+}
+
 } // end namespace linalg
 } // end namespace mlir

diff  --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
new file mode 100644
index 0000000000000..095a6636b68dc
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s
+
+func.func @conv2d(%input: tensor<2x11x11x5xf32>, %filter: tensor<2x3x3x5xf32>, %init: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %filter_empty = tensor.empty() : tensor<6x6x5x2xf32>
+  %filter_xform = linalg.winograd_filter_transform m(4) r(3) ins(%filter : tensor<2x3x3x5xf32>) outs(%filter_empty : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %input_padded = tensor.pad %input low[0, 0, 0, 0] high[0, 3, 3, 0] {
+  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+    tensor.yield %zero : f32
+  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+  %input_empty = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+  %input_xform = linalg.winograd_input_transform m(4) r(3) ins(%input_padded : tensor<2x14x14x5xf32>) outs(%input_empty : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  %filter_collapsed = tensor.collapse_shape %filter_xform [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %input_collapsed = tensor.collapse_shape %input_xform [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+  %matmul_empty = tensor.empty() : tensor<36x18x2xf32>
+  %matmul = linalg.batch_matmul ins(%input_collapsed, %filter_collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%matmul_empty : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %matmul_expanded = tensor.expand_shape %matmul [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+  %init_padded = tensor.pad %init low[0, 0, 0, 0] high[0, 3, 3, 0] {
+  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+    tensor.yield %zero : f32
+  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+  %output_xform = linalg.winograd_output_transform m(4) r(3) ins(%matmul_expanded : tensor<6x6x3x3x2x2xf32>) outs(%init_padded : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %result = tensor.extract_slice %output_xform[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+  return %result : tensor<2x9x9x2xf32>
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:   %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:   %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], %[[C0]], %[[C0]], %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>
+// CHECK-NEXT:       %[[S8:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:       %[[S9:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S8]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:       %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:       %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:   ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
+// CHECK-NEXT:     tensor.yield %[[CST_6]] : f32
+// CHECK-NEXT:   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:           %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:           %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S10]], %[[S11]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S13:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S15:.*]] = linalg.matmul ins(%[[S13]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_7]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT:   %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:   ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
+// CHECK-NEXT:     tensor.yield %[[CST_6]] : f32
+// CHECK-NEXT:   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:           %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:           %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:           ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:             linalg.yield %[[IN]] : f32
+// CHECK-NEXT:           } -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S16:.*]] = linalg.mul ins(%[[S15]], %[[S13]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:           %[[S18:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG10]][%[[ARG7]], %[[S17]], %[[S18]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S9]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield %[[S8]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 12cb46a5968f1..5899f56da7345 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -127,6 +127,9 @@ struct TestLinalgTransforms
       *this, "test-winograd-conv2d",
       llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
       llvm::cl::init(false)};
+  Option<bool> testDecomposeWinogradOps{
+      *this, "test-decompose-winograd-ops",
+      llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
 };
 } // namespace
 
@@ -218,6 +221,12 @@ static void applyWinogradConv2D(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
+  RewritePatternSet decomposePatterns(funcOp.getContext());
+  populateDecomposeWinogradOpsPatterns(decomposePatterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(decomposePatterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -244,6 +253,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnnecessaryInputs(getOperation());
   if (testWinogradConv2D)
     return applyWinogradConv2D(getOperation());
+  if (testDecomposeWinogradOps)
+    return applyDecomposeWinogradOps(getOperation());
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list