[Mlir-commits] [mlir] d16f42d - [mlir][linalg] Constrain the parameters m, r in Winograd ops (#144657)

llvmlistbot at llvm.org
Wed Jun 25 06:02:10 PDT 2025


Author: Hsiangkai Wang
Date: 2025-06-25T14:02:07+01:00
New Revision: d16f42d1e227169cace475e7715779b8dadb79c2

URL: https://github.com/llvm/llvm-project/commit/d16f42d1e227169cace475e7715779b8dadb79c2
DIFF: https://github.com/llvm/llvm-project/commit/d16f42d1e227169cace475e7715779b8dadb79c2.diff

LOG: [mlir][linalg] Constrain the parameters m, r in Winograd ops (#144657)

We only support a fixed set of minimal filtering algorithms for the Winograd
Conv2D decomposition. Instead of letting users specify arbitrary integers,
define a fixed set of enumeration values for the parameters of the minimal
filtering algorithm.
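
For illustration, the snippet below sketches how a caller might adopt the new
API: user-facing (m, r) integers are first mapped onto the enumeration via the
getWinogradConv2DFmr helper added in this patch, so unsupported configurations
are rejected up front. The wrapper function itself is hypothetical; only the
two linalg entry points come from this change.

#include <optional>

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical adapter: map raw (m, r) integers onto the constrained enum
// before invoking the Winograd rewrite. Only F(2,3), F(4,3), and F(2,5)
// are representable; anything else fails here rather than deep inside the
// rewrite.
static LogicalResult applyWinograd(RewriterBase &rewriter,
                                   linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                   int64_t r) {
  std::optional<linalg::WinogradConv2DFmr> fmr =
      linalg::getWinogradConv2DFmr(m, r);
  if (!fmr)
    return failure();
  if (failed(linalg::winogradConv2D(rewriter, op, *fmr)))
    return failure();
  return success();
}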

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
    mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
    mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
    mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
    mlir/test/Dialect/Linalg/winograd-conv2d.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..4f5fea107f07b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -100,6 +100,20 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
 
 #include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
 
+namespace mlir {
+namespace linalg {
+
+/// Converts the given `m` and `r` parameters to a WinogradConv2DFmr enumeration
+/// value.
+std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r);
+
+/// Converts the given WinogradConv2DFmr enumeration value to a pair of
+/// m and r parameters.
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
+
+} // namespace linalg
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Linalg Attributes
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index ce68afe471fe8..1109db973f522 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -122,4 +122,19 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
 
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr", "", [
+      I32EnumAttrCase<"F_2_3", 0>,
+      I32EnumAttrCase<"F_4_3", 1>,
+      I32EnumAttrCase<"F_2_5", 2>,
+]>{
+  let cppNamespace = "mlir::linalg";
+}
+
 #endif // LINALG_ENUMS
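
To make the alpha = m + r - 1 relation above concrete, here is a small
illustrative helper (not part of the patch) that derives the input tile size
from an enum value using the decoding function declared in Linalg.h:

#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Illustrative only: input tile size alpha = m + r - 1 for each supported
// configuration. F_2_3 -> 4, F_4_3 -> 6, F_2_5 -> 6, matching the 4x4 and
// 6x6 transform matrices in WinogradConv2D.cpp.
static int64_t getAlpha(mlir::linalg::WinogradConv2DFmr fmr) {
  auto [m, r] = mlir::linalg::getFmrFromWinogradConv2DFmr(fmr);
  return m + r - 1;
}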

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1b48bf5fcb237..7ff44c2e1d2ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,15 +183,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
                        TensorRankOf<[AnyType], [4]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $filter `:` type($filter) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
@@ -254,15 +252,13 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
                        TensorRankOf<[AnyType], [6]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [6]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $input `:` type($input) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
@@ -343,15 +339,13 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
                        TensorRankOf<[AnyType], [4]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $value `:` type($value) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 9d6ce653e285c..d64f94a49f781 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -9,6 +9,7 @@
 #ifndef LINALG_TRANSFORM_OPS
 #define LINALG_TRANSFORM_OPS
 
+include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
 include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -2902,8 +2903,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$m,
-                       I64Attr:$r);
+                       WinogradConv2DFmr:$fmr);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
   let assemblyFormat =

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 189438e9ad528..2b4855f49695c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -37,6 +37,7 @@ class BufferizationState;
 namespace linalg {
 
 class LinalgOp;
+enum class WinogradConv2DFmr : uint32_t;
 
 //===----------------------------------------------------------------------===//
 // Utils.
@@ -1426,8 +1427,8 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
 /// F(m x m, r x r). m is the dimension size of output and r is the dimension
 /// size of filter.
 FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r);
+                                      linalg::Conv2DNhwcFhwcOp op,
+                                      WinogradConv2DFmr fmr);
 
 /// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
 /// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
@@ -1968,8 +1969,8 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
 /// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r);
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
+                                    WinogradConv2DFmr fmr);
 
 /// Patterns to decompose Winograd operators.
 void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
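
As a hedged usage sketch for the updated signatures (the pass scaffolding
below is an assumption; only the two populate entry points come from this
header):

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Sketch: drive the Winograd rewrite for F(4x4, 3x3) with greedy pattern
// application. The surrounding function is hypothetical.
static LogicalResult runWinogradF43(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateWinogradConv2DPatterns(patterns,
                                         linalg::WinogradConv2DFmr::F_4_3);
  linalg::populateDecomposeWinogradOpsPatterns(patterns);
  // Driver name as of current MLIR; older trees spell this
  // applyPatternsAndFoldGreedily.
  return applyPatternsGreedily(root, std::move(patterns));
}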

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..b2639edb0d0f5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2989,8 +2989,9 @@ LogicalResult WinogradFilterTransformOp::verify() {
   ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterH = filterShape[getFilterHDim()];
   int64_t filterW = filterShape[getFilterWDim()];
-  int64_t r = getR();
-  int64_t m = getM();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   if (filterH != r && filterH != 1)
     return emitOpError("expect filter height either equals to r or 1");
@@ -3046,8 +3047,9 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
   ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterH = filterShape[getFilterHDim()];
   int64_t filterW = filterShape[getFilterWDim()];
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t alpha = m + r - 1;
   int64_t alphaH = filterH != 1 ? alpha : 1;
   int64_t alphaW = filterW != 1 ? alpha : 1;
@@ -3124,8 +3126,9 @@ LogicalResult WinogradInputTransformOp::verify() {
   ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[getInputHDim()];
   int64_t inputW = inputShape[getInputWDim()];
-  int m = getM();
-  int r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t tileSize = m + r - 1;
 
   auto outputType = cast<ShapedType>(getOutput().getType());
@@ -3194,8 +3197,9 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
   int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
   int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
 
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t alpha = m + r - 1;
   int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
   int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
@@ -3224,8 +3228,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   ShapedType outputType = getOutputOperandType();
   ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3303,8 +3308,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
   int64_t valueW = valueShape[getValueAlphaWDim()];
   int64_t valueTileH = valueShape[getValueTileHDim()];
   int64_t valueTileW = valueShape[getValueTileWDim()];
-  int m = getM();
-  int r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   bool leftTransform = valueH != 1;
   bool rightTransform = valueW != 1;
 
@@ -3365,7 +3371,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  int64_t m = getM();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
@@ -3623,6 +3631,27 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
 namespace mlir {
 namespace linalg {
 
+std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
+  if (m == 2 && r == 3)
+    return WinogradConv2DFmr::F_2_3;
+  if (m == 4 && r == 3)
+    return WinogradConv2DFmr::F_4_3;
+  if (m == 2 && r == 5)
+    return WinogradConv2DFmr::F_2_5;
+  return std::nullopt;
+}
+
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
+  switch (fmr) {
+  case WinogradConv2DFmr::F_2_3:
+    return {2, 3};
+  case WinogradConv2DFmr::F_4_3:
+    return {4, 3};
+  case WinogradConv2DFmr::F_2_5:
+    return {2, 5};
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // MatMulOp
 //===----------------------------------------------------------------------===//
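
The two helpers defined above are inverses on the supported set; an
illustrative (non-committed) sanity check:

#include <cassert>

#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir::linalg;

// Illustrative: encode/decode round-trips for every supported case, and
// unsupported (m, r) pairs decode to std::nullopt.
static void checkFmrRoundTrip() {
  for (WinogradConv2DFmr fmr :
       {WinogradConv2DFmr::F_2_3, WinogradConv2DFmr::F_4_3,
        WinogradConv2DFmr::F_2_5}) {
    auto [m, r] = getFmrFromWinogradConv2DFmr(fmr);
    assert(getWinogradConv2DFmr(m, r) == fmr);
  }
  assert(!getWinogradConv2DFmr(3, 3)); // not a supported F(m, r)
}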

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2b78e31558ea2..8571d641e26d1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4250,7 +4250,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
   bool supported = TypeSwitch<Operation *, bool>(target)
                        .Case([&](linalg::Conv2DNhwcFhwcOp op) {
                          maybeTransformed =
-                             winogradConv2D(rewriter, op, getM(), getR());
+                             winogradConv2D(rewriter, op, getFmr());
                          return true;
                        })
                        .Default([&](Operation *op) { return false; });

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index e4221d4748415..c61b23c63dc56 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -38,6 +39,15 @@ namespace {
 ///
 /// The following tables define these constant transformation matrices for
 /// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
+///
+/// To add more transformation matrices, we need to add the following
+/// items:
+/// 1. Add the constant transformation matrix to the corresponding
+///   G, GT, BT, B, AT, or A array.
+/// 2. Add the corresponding TransformMatrix to the GMatrices, GTMatrices,
+///   BTMatrices, BMatrices, ATMatrices, or AMatrices map.
+/// 3. Add a enum value F_m_r to WinogradConv2DFmr enum.
+///
 constexpr float G_2x2_3x3[] = {
    -1,     0,   0,
  1./2, -1./2, 1./2,
@@ -176,19 +186,6 @@ constexpr float A_2x2_5x5[] = {
 };
 // clang-format on
 
-using TransformMapKeyTy = std::pair<int, int>;
-
-/// We use F(m, r) to define the size of minimal filtering algorithms.
-/// m is the output dimension and r is the filter dimension. We can get
-/// the input dimension, alpha, from the formula, alpha = m + r - 1.
-///
-/// For example, when m = 2 and r = 3, we know its input size is 4.
-/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-/// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
   TransformMatrix(const float *table, int64_t rows, int64_t cols,
@@ -344,22 +341,22 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
 ///     %ret = linalg.matmul %ret, GT
 ///     %inserted = insert %ret into filter<h x w x c x f>
 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
-                      Value retValue, int64_t m, int64_t r,
+                      Value retValue, WinogradConv2DFmr fmr,
                       bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to G transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       GMatrices = {
-          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
-          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
-          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
       };
 
   // Map from (m, r) to GT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       GTMatrices = {
-          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
-          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
-          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
       };
 
   auto filterType = cast<ShapedType>(filter.getType());
@@ -370,6 +367,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
   int64_t filterW = filterShape[2];
   int64_t filterC = filterShape[3];
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   if (filterH != r && filterH != 1)
     return Value();
   if (filterW != r && filterW != 1)
@@ -387,14 +386,13 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                             zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
-    TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     Value matmulRetValue = extractFilter;
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix G.
-      auto it = GMatrices.find(key);
+      auto it = GMatrices.find(fmr);
       if (it == GMatrices.end())
         return {};
       const TransformMatrix &GMatrix = it->second;
@@ -416,7 +414,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 
     if (rightTransform) {
       // Get constant transform matrix GT.
-      auto it = GTMatrices.find(key);
+      auto it = GTMatrices.find(fmr);
       if (it == GTMatrices.end())
         return {};
       const TransformMatrix &GTMatrix = it->second;
@@ -476,24 +474,26 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 ///                            %output<alphaH x alphaW x tileH x tileW x N x C>
 ///                            at [0, 0, %h, %w, %n, %c]
 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
-                     Value retValue, int64_t m, int64_t r,
+                     Value retValue, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to BT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       BTMatrices = {
-          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
       };
 
   // Map from (m, r) to B transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       BMatrices = {
-          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
       };
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto inputType = cast<ShapedType>(input.getType());
   Type elementType = inputType.getElementType();
   auto inputShape = inputType.getShape(); // N, H, W, C
@@ -529,7 +529,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                             widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
-    TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     int64_t retCols = 1;
     Value matmulRetValue = extractInput;
@@ -537,7 +536,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
         loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix BT.
-      auto it = BTMatrices.find(key);
+      auto it = BTMatrices.find(fmr);
       if (it == BTMatrices.end())
         return {};
       const TransformMatrix &BTMatrix = it->second;
@@ -560,7 +559,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
 
     if (rightTransform) {
       // Get constant transform matrix B.
-      auto it = BMatrices.find(key);
+      auto it = BMatrices.find(fmr);
       if (it == BMatrices.end())
         return {};
       const TransformMatrix &BMatrix = it->second;
@@ -696,24 +695,26 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
 ///                            output<N x H x W x F>
 ///                            at [%n, (%h x m), (%w x m), %f]
 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
-                      Value output, int64_t m, int64_t r,
+                      Value output, WinogradConv2DFmr fmr,
                       bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to AT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       ATMatrices = {
-          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
-          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
-          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
       };
 
   // Map from (m, r) to A transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       AMatrices = {
-          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
-          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
-          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
       };
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto valueType = cast<ShapedType>(value.getType());
   Type elementType = valueType.getElementType();
   auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
@@ -743,9 +744,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                             FIter, 2, 3, /*loopNorFIdx=*/4,
                             /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
 
-    const TransformMapKeyTy key = {m, r};
-    const TransformMatrix &AMatrix = AMatrices.at(key);
-    const TransformMatrix &ATMatrix = ATMatrices.at(key);
+    const TransformMatrix &AMatrix = AMatrices.at(fmr);
+    const TransformMatrix &ATMatrix = ATMatrices.at(fmr);
     int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                            (leftTransform ? ATMatrix.scalarFactor : 1);
     int64_t retCols = rightTransform ? AMatrix.cols : 1;
@@ -903,7 +903,7 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
 /// linalg.winograd_*_transform ops.
 static FailureOr<Operation *>
 winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
-                     int64_t m, int64_t r) {
+                     WinogradConv2DFmr fmr) {
   if (!convOp.hasPureTensorSemantics())
     return rewriter.notifyMatchFailure(
         convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");
@@ -946,6 +946,8 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   int64_t outputW = outputShape[2];
   int64_t outputF = outputShape[3];
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
   bool isSupportedFilter = false;
   if (filterH == filterW && filterH == r)
@@ -959,17 +961,6 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
     return rewriter.notifyMatchFailure(
         convOp, "only support filter (r x r), (r x 1) or (1 x r)");
 
-  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
-  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
-      F_2_3, F_4_3, F_2_5};
-
-  TransformMapKeyTy key = {m, r};
-  auto it = llvm::find(validConfigs, key);
-  // If we cannot find the constant transformation matrix, it means we do
-  // not support this configuration yet.
-  if (it == validConfigs.end())
-    return failure();
-
   // All the criterias are satisfied. We can do Winograd Conv2D.
   Location loc = convOp.getLoc();
 
@@ -993,7 +984,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                     filterElementType);
   auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
-      loc, retType, filter, retValue, m, r);
+      loc, retType, filter, retValue, fmr);
 
   // --- Create operation for input transform ---
 
@@ -1012,7 +1003,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                               inputElementType);
   auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
-      loc, retType, input, retValue, m, r);
+      loc, retType, input, retValue, fmr);
 
   Type outputElementType = outputType.getElementType();
   Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
@@ -1035,7 +1026,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   }
 
   Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
-      loc, outputType, matmulRet, output, m, r);
+      loc, outputType, matmulRet, output, fmr);
 
   // When output size is not aligned with output tile size, extract the
   // value from the padded buffer.
@@ -1067,8 +1058,8 @@ decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
   // For F(1 x m, 1 x r), we only need to do right side transform.
   bool rightTransform = filterW != 1;
   Value transformedFilter =
-      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
-                      op.getR(), leftTransform, rightTransform);
+      filterTransform(rewriter, loc, filter, op.getOutput(), op.getFmr(),
+                      leftTransform, rightTransform);
   if (!transformedFilter)
     return failure();
 
@@ -1094,8 +1085,8 @@ decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
   // For F(1 x m, 1 x r), we only need to do right side transform.
   bool rightTransform = outputW != 1;
   Value transformedInput =
-      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
-                     op.getR(), leftTransform, rightTransform);
+      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getFmr(),
+                     leftTransform, rightTransform);
   if (!transformedInput)
     return failure();
 
@@ -1120,8 +1111,8 @@ decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
   // For F(1 x m, 1 x r), we only need to do right side transform.
   bool rightTransform = valueW != 1;
   Value transformedOutput =
-      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
-                      op.getR(), leftTransform, rightTransform);
+      outputTransform(rewriter, loc, value, op.getOutput(), op.getFmr(),
+                      leftTransform, rightTransform);
   if (!transformedOutput)
     return failure();
 
@@ -1171,28 +1162,28 @@ class WinogradConv2DNhwcFhwc final
     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
-  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
-      : OpRewritePattern(context), m(m), r(r) {}
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, WinogradConv2DFmr fmr)
+      : OpRewritePattern(context), fmr(fmr) {}
 
   LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                 PatternRewriter &rewriter) const override {
-    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+    if (failed(winogradConv2DHelper(rewriter, convOp, fmr)))
       return failure();
 
     return success();
   }
 
 private:
-  int64_t m;
-  int64_t r;
+  WinogradConv2DFmr fmr;
 };
+
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
 FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r) {
-  return winogradConv2DHelper(rewriter, op, m, r);
+                                      linalg::Conv2DNhwcFhwcOp op,
+                                      linalg::WinogradConv2DFmr fmr) {
+  return winogradConv2DHelper(rewriter, op, fmr);
 }
 
 FailureOr<Operation *>
@@ -1213,11 +1204,11 @@ decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
   return decomposeWinogradOutputTransformHelper(rewriter, op);
 }
 
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r) {
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
+                                    WinogradConv2DFmr fmr) {
   MLIRContext *context = patterns.getContext();
   // TODO: Support more Conv2D data layout, e.g., conv_2d_nchw_fchw
-  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
+  patterns.insert<WinogradConv2DNhwcFhwc>(context, fmr);
 }
 
 void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
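
The "to add more transformation matrices" recipe in the file-top comment can
be made concrete. Below is a purely hypothetical sketch for an F(6, 3)
configuration; F_6_3 is not supported by this patch and the coefficient
values are deliberately elided:

// Step 1: define the constant matrices (alpha = 6 + 3 - 1 = 8, so G would
// be 8x3; real coefficient values would have to be derived and are omitted).
constexpr float G_6x6_3x3[8 * 3] = {/* derived coefficients */};

// Step 2: register it in the corresponding map inside filterTransform():
//   {WinogradConv2DFmr::F_6_3, TransformMatrix(G_6x6_3x3, 8, 3)},
// and likewise for the GT, BT, B, AT, and A maps.

// Step 3: add the enum case in LinalgEnums.td,
//   I32EnumAttrCase<"F_6_3", 3>,
// and extend getWinogradConv2DFmr / getFmrFromWinogradConv2DFmr to map
// (6, 3) <-> F_6_3.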

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index ca40301f04fa1..cbc863699ba9e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1165,7 +1165,7 @@ func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<
 
 func.func @winograd_filter_transform_height(%arg0: tensor<2x4x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
   // expected-error @+1 {{expect filter height either equals to r or 1}}
-  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x4x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x4x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   return %0 : tensor<6x6x5x2xf32>
 }
 
@@ -1173,7 +1173,7 @@ func.func @winograd_filter_transform_height(%arg0: tensor<2x4x3x5xf32>, %arg1: t
 
 func.func @winograd_filter_transform_width(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
   // expected-error @+1 {{expect filter width either equals to r or 1}}
-  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x4x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x4x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   return %0 : tensor<6x6x5x2xf32>
 }
 
@@ -1181,7 +1181,7 @@ func.func @winograd_filter_transform_width(%arg0: tensor<2x3x4x5xf32>, %arg1: te
 
 func.func @winograd_filter_transform(%arg0: tensor<2x1x1x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
   // expected-error @+1 {{expect either filter height or width equals to r}}
-  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x1x1x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x1x1x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   return %0 : tensor<6x6x5x2xf32>
 }
 
@@ -1189,7 +1189,7 @@ func.func @winograd_filter_transform(%arg0: tensor<2x1x1x5xf32>, %arg1: tensor<6
 
 func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32>
+  %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32>
   return %0 : tensor<6x5x?x?xf32>
 }
 
@@ -1197,7 +1197,7 @@ func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x5x?x?
 
 func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x13x14x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x13x14x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
   return %0 : tensor<6x6x3x3x2x5xf32>
 }
 
@@ -1205,7 +1205,7 @@ func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1:
 
 func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x13x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x13x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
   return %0 : tensor<6x6x3x3x2x5xf32>
 }
 
@@ -1213,7 +1213,7 @@ func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: t
 
 func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32>
   return %0 : tensor<6x6x2x3x2x5xf32>
 }
 
@@ -1221,7 +1221,7 @@ func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %
 
 func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32>
   return %0 : tensor<6x6x3x2x2x5xf32>
 }
 
@@ -1229,7 +1229,7 @@ func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %
 
 func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32>
   return %0 : tensor<5x6x3x3x2x5xf32>
 }
 
@@ -1237,7 +1237,7 @@ func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>,
 
 func.func @winograd_input_transform_output_width(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32>
   return %0 : tensor<6x5x3x3x2x5xf32>
 }
 
@@ -1245,7 +1245,7 @@ func.func @winograd_input_transform_output_width(%arg0: tensor<2x14x14x5xf32>, %
 
 func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32>
   return %0 : tensor<6x5x?x?x?x?xf32>
 }
 
@@ -1253,7 +1253,7 @@ func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x5x?x?x
 
 func.func @winograd_output_transform_input_height(%arg0: tensor<5x6x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
   // expected-error @+1 {{expect input height equals to input tile size}}
-  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<5x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<5x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
   return %0 : tensor<2x12x12x2xf32>
 }
 
@@ -1261,7 +1261,7 @@ func.func @winograd_output_transform_input_height(%arg0: tensor<5x6x3x3x2x2xf32>
 
 func.func @winograd_output_transform_input_width(%arg0: tensor<6x5x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
   // expected-error @+1 {{expect input width equals to input tile size}}
-  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x5x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x5x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
   return %0 : tensor<2x12x12x2xf32>
 }
 
@@ -1269,7 +1269,7 @@ func.func @winograd_output_transform_input_width(%arg0: tensor<6x5x3x3x2x2xf32>,
 
 func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32>
+  %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32>
   return %0 : tensor<2x11x12x2xf32>
 }
 
@@ -1277,7 +1277,7 @@ func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32
 
 func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32> {
   // expected-error @+1 {{the output shape is not expected}}
-  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
+  %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
   return %0 : tensor<2x12x11x2xf32>
 }
 

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index dc556761b09e5..4edbc6eda3eae 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -630,52 +630,52 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 
 func.func @winograd(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
   %0 = tensor.empty() : tensor<6x6x5x2xf32>
-  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   %2 = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%2 : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%2 : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
   %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
   %4 = tensor.empty() : tensor<36x2x2xf32>
   %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
   %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  %6 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
   return %6 : tensor<2x4x4x2xf32>
 }
 
 // CHECK-LABEL: func @winograd
-// CHECK:         linalg.winograd_filter_transform m(4) r(3)
-// CHECK:         linalg.winograd_input_transform m(4) r(3)
-// CHECK:         linalg.winograd_output_transform m(4) r(3)
+// CHECK:         linalg.winograd_filter_transform fmr(F_4_3)
+// CHECK:         linalg.winograd_input_transform fmr(F_4_3)
+// CHECK:         linalg.winograd_output_transform fmr(F_4_3)
 
 // -----
 
 func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32> {
-  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
+  %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
   return %0 : tensor<6x6x?x?xf32>
 }
 
 // CHECK-LABEL: func @winograd_filter_dyn
-// CHECK:         linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
+// CHECK:         linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
 
 // -----
 
 func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32> {
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
   return %0 : tensor<6x6x?x?x?x?xf32>
 }
 
 // CHECK-LABEL: func @winograd_input_dyn
-// CHECK:         linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
+// CHECK:         linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
 
 // -----
 
 func.func @winograd_output_dyn(%arg0: tensor<6x6x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
-  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   return %0 : tensor<?x?x?x?xf32>
 }
 
 // CHECK-LABEL: func @winograd_output_dyn
-// CHECK:         linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK:         linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 
 // -----
 

diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index cdc4b8a72a276..445ded4bfcafb 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -2,15 +2,15 @@
 
 func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
   %0 = tensor.empty() : tensor<6x6x5x2xf32>
-  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   %2 = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
   %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
   %4 = tensor.empty() : tensor<36x8x2xf32>
   %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
   %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  %6 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
   return %6 : tensor<2x8x8x2xf32>
 }
 
@@ -123,13 +123,13 @@ module attributes {transform.with_named_sequence} {
 func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<6x6x5x2xf32>
-  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
   ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
     tensor.yield %cst : f32
   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
   %2 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-  %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
   %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
   %4 = tensor.empty() : tensor<36x18x2xf32>
@@ -140,7 +140,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
   ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
     tensor.yield %cst : f32
   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
   %extracted_slice = tensor.extract_slice %7[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
   return %extracted_slice : tensor<2x9x9x2xf32>
 }
@@ -259,16 +259,16 @@ module attributes {transform.with_named_sequence} {
 func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<6x1x5x2xf32>
-  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
   %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
   %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
   %4 = tensor.empty() : tensor<6x2x2xf32>
   %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
   %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
   %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
   return %7 : tensor<2x4x1x2xf32>
 }
 
@@ -350,16 +350,16 @@ module attributes {transform.with_named_sequence} {
 func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<6x1x5x2xf32>
-  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
   %2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
-  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
   %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
   %4 = tensor.empty() : tensor<6x4x2xf32>
   %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
   %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
   %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
-  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
+  %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
   return %7 : tensor<2x4x2x2xf32>
 }
 

diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
index fc6424fd4c812..beb8d0b125738 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s
 
 func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
-  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   return %0 : tensor<6x6x5x2xf32>
 }
 
@@ -25,13 +25,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK:    %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
 // CHECK:      %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
-// CHECK:      %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+// CHECK:      %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
 // CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S3]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
 
 // -----
 
 func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
-  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   return %0 : tensor<6x6x5x2xf32>
 }
 
@@ -58,12 +58,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
 // CHECK:       %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
-// CHECK:       %[[S4:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
+// CHECK:       %[[S4:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
 // CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S4]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x?x1xf32> into tensor<6x6x5x2xf32>
 // -----
 
 func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
-  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
   return %0 : tensor<6x1x5x2xf32>
 }
 
@@ -87,13 +87,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
 // CHECK:       %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
-// CHECK:       %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+// CHECK:       %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
 // CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S3]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
 
 // -----
 
 func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
   return %0 : tensor<6x6x2x2x2x5xf32>
 }
 
@@ -123,13 +123,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK:   %[[S6:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
 // CHECK:   %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK:   %[[S7:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK:   %[[S7:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
 // CHECK:   %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S7]] into %[[ARG5]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
 
 // -----
 
 func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
   return %0 : tensor<6x6x2x2x2x5xf32>
 }
 
@@ -167,13 +167,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK:         %[[S8:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
 // CHECK:         %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
-// CHECK:         %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK:         %[[S9:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
 // CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x2x2x2x5xf32>
 
 // -----
 
 func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
   return %0 : tensor<6x6x2x2x2x5xf32>
 }
 
@@ -213,13 +213,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK:         %[[S9:.*]] = affine.apply #[[$MAP2]]()
 // CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
 // CHECK:         %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
-// CHECK:         %[[S10:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+// CHECK:         %[[S10:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
 // CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x?xf32> into tensor<6x6x2x2x2x5xf32>
 
 // -----
 
 func.func @tile_winograd_input(%arg0: tensor<2x1x10x5xf32>, %arg1: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
-  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
+  %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
   return %0 : tensor<1x6x1x2x2x5xf32>
 }
 
@@ -258,13 +258,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[S8:.*]] = affine.apply #[[$MAP2]]()
 // CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
 // CHECK:           %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
-// CHECK:           %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+// CHECK:           %[[S9:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
 // CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x1x1x1xf32> into tensor<1x6x1x2x2x5xf32>
 
 // -----
 
 func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x2x2xf32>, %arg1: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
   return %0 : tensor<2x8x8x2xf32>
 }
 
@@ -298,7 +298,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x3x5xf32>, %arg1: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
-  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
+  %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
   return %0 : tensor<3x8x8x5xf32>
 }
 
@@ -346,7 +346,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @tile_winograd_output(%arg0 : tensor<6x1x2x1x3x5xf32>, %arg1: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
-  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
+  %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
   return %0 : tensor<3x8x1x5xf32>
 }
 
@@ -385,4 +385,4 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[S7:.*]] = affine.apply #[[$MAP2]]()
 // CHECK:           %[[S8:.*]] = affine.apply #[[$MAP2]]()
 // CHECK:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
-// CHECK:           %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>
+// CHECK:           %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>

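Throughout these tiling tests, only the channel (C) and filter (F) loops are tiled; the transformed tile dims are carried whole through every extract_slice/insert_slice, and the fmr attribute is copied onto each tiled op unchanged. A hedged sketch of the slice extents in the first test above (plain C++, illustrative names only):

    #include <array>
    #include <cstdint>
    // Per-iteration extents when tiling F and C by 1 for fmr(F_4_3):
    constexpr std::array<int64_t, 4> filterSlice = {1, 3, 3, 1}; // [F, H, W, C]
    constexpr std::array<int64_t, 4> resultSlice = {6, 6, 1, 1}; // [a, a, C, F]
    // The 3x3 filter maps to a 6x6 transformed tile (a = m + r - 1 = 6),
    // so the leading dims of the result slice never shrink under tiling.
    static_assert(filterSlice[1] == 3 && resultSlice[0] == 6, "F_4_3 extents");
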
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index 1de861e653005..e0ead54c956fc 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s
 
 func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
@@ -8,16 +8,16 @@ func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
 
 // CHECK-LABEL: func.func @conv2d
-// CHECK: linalg.winograd_filter_transform m(4) r(3)
-// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: linalg.winograd_filter_transform fmr(F_4_3)
+// CHECK: linalg.winograd_input_transform fmr(F_4_3)
 // CHECK: linalg.batch_matmul
-// CHECK: linalg.winograd_output_transform m(4) r(3)
+// CHECK: linalg.winograd_output_transform fmr(F_4_3)
 
 // -----
 
@@ -29,19 +29,19 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
 
 // CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK:       linalg.winograd_filter_transform m(4) r(3)
+// CHECK:       linalg.winograd_filter_transform fmr(F_4_3)
 // CHECK:       tensor.pad
 // CHECK-SAME:  low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK:       linalg.winograd_input_transform m(4) r(3)
+// CHECK:       linalg.winograd_input_transform fmr(F_4_3)
 // CHECK:       tensor.pad
 // CHECK-SAME:  low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK:       linalg.winograd_output_transform m(4) r(3)
+// CHECK:       linalg.winograd_output_transform fmr(F_4_3)
 
 // -----
 
@@ -54,7 +54,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}}
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -70,7 +70,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error @+1 {{apply Winograd Conv2D failed}}
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -86,7 +86,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error @+1 {{apply Winograd Conv2D failed}}
-    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }

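Note how the transform scripts above now pass the raw enum form `fmr = 1 : i32`, which the ops print back symbolically as `fmr(F_4_3)`. Only that one value is pinned down by these tests; a standalone sketch of the correspondence (the sketch enum name is hypothetical; the real enumerators live in LinalgEnums.td):

    #include <cstdint>
    // Sketch only: these tests fix F_4_3 = 1; any other enumerator values
    // follow the declaration order in LinalgEnums.td, not shown here.
    enum class WinogradConv2DFmrSketch : int32_t {
      F_4_3 = 1, // { fmr = 1 : i32 }  <->  fmr(F_4_3)
    };
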
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 16d06a7473272..c7b0bd51308ba 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -3,13 +3,13 @@
 func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %2 = tensor.empty() : tensor<6x6x5x2xf32>
-  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %3 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
   %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
   ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
     tensor.yield %cst : f32
   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
   %4 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-  %5 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%4 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  %5 = linalg.winograd_input_transform fmr(F_4_3) ins(%padded : tensor<2x14x14x5xf32>) outs(%4 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
   %collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
   %6 = tensor.empty() : tensor<36x18x2xf32>
@@ -20,7 +20,7 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
   ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
     tensor.yield %cst : f32
   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-  %9 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %9 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
   %extracted_slice = tensor.extract_slice %9[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
   return %extracted_slice : tensor<2x9x9x2xf32>
 }

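The padding amounts in the rewrite above fall straight out of the tile arithmetic for F_4_3 on an unaligned 11x11 input with a 9x9 output. A worked check (plain C++, variable names are illustrative, not from the patch):

    #include <cstdint>
    constexpr int64_t m = 4, r = 3, outHW = 9;
    constexpr int64_t tiles = (outHW + m - 1) / m;  // 3 tiles per spatial dim
    constexpr int64_t paddedOut = tiles * m;        // 12 -> tensor<2x12x12x2xf32>
    constexpr int64_t paddedIn = paddedOut + r - 1; // 14 -> tensor<2x14x14x5xf32>
    static_assert(tiles == 3 && paddedOut == 12 && paddedIn == 14,
                  "matches the tensor.pad and 6x6x3x3x2x5 shapes above");

The trailing tensor.extract_slice then trims the 12x12 padded result back to the requested 9x9.
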
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index 0040d81a2d24e..e80fa6b4af944 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -9,16 +9,16 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
 // CHECK-NEXT:  %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
 // CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
 // CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
 // CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
 // CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
 // CHECK-NEXT:  %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
 // CHECK-NEXT:  %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
 // CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:  %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
 // CHECK-NEXT:  return %[[S9]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT: }
 
@@ -33,16 +33,16 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
 // CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_2_5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform fmr(F_2_5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
 // CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
 // CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform fmr(F_2_5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
 // CHECK-NEXT:   return %[[S9]] : tensor<2x2x2x2xf32>
 // CHECK-NEXT: }
 
@@ -57,16 +57,16 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
 // CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
 // CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
 // CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
 // CHECK-NEXT:   return %[[S9]] : tensor<2x1x4x2xf32>
 // CHECK-NEXT: }
 
@@ -81,16 +81,16 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
 // CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
 // CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
 // CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
 // CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
 // CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
 // CHECK-NEXT:   return %[[S9]] : tensor<2x4x1x2xf32>
 // CHECK-NEXT: }
 
@@ -105,16 +105,16 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
 // CHECK-NEXT:  %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
 // CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
 // CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
 // CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
 // CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
 // CHECK-NEXT:  %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
 // CHECK-NEXT:  %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
 // CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK-NEXT:  %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
 // CHECK-NEXT:  return %[[S9]] : tensor<2x8x8x2xf32>
 // CHECK-NEXT: }
 
@@ -129,13 +129,13 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
 // CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
 // CHECK-NEXT:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
 // CHECK-NEXT:  ^bb0
 // CHECK-NEXT:    tensor.yield %[[CST]] : f32
 // CHECK-NEXT:  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
 // CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
 // CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %3 {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
 // CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
@@ -146,7 +146,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 // CHECK-NEXT:  ^bb0
 // CHECK-NEXT:    tensor.yield %[[CST]] : f32
 // CHECK-NEXT:  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
 // CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
@@ -162,16 +162,16 @@ func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
 // CHECK:        %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
 // CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
 // CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
 // CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
 // CHECK-NEXT:   return %[[S7]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT: }
 

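All configurations exercised above land on the same 6x6 transformed tile, which is why conv2d_2x2_5x5 (F_2_5) reuses the tensor<6x6x5x2xf32> filter layout of the F_4_3 cases. A minimal sketch of that invariant (hypothetical helper, not part of the patch):

    #include <cstdint>
    // Transformed tile size for a minimal filtering algorithm F(m, r).
    constexpr int64_t winogradAlpha(int64_t m, int64_t r) { return m + r - 1; }
    static_assert(winogradAlpha(4, 3) == 6, "F_4_3 -> 6x6 tiles");
    static_assert(winogradAlpha(2, 5) == 6, "F_2_5 -> 6x6 tiles");

For the 1-D variants (conv2d_1x4_1x3 and conv2d_4x1_3x1), only the non-unit spatial dim is transformed, hence the 1x6 and 6x1 leading dims in their tensors.
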
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 738648b8ccdcf..3160cad9b30bf 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -211,8 +212,8 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
 
 static void applyWinogradConv2D(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
-  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
-  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
+  populateWinogradConv2DPatterns(patterns, WinogradConv2DFmr::F_4_3);
+  populateWinogradConv2DPatterns(patterns, WinogradConv2DFmr::F_2_5);
   (void)applyPatternsGreedily(funcOp, std::move(patterns));
 }
 

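Out-of-tree callers that still carry integer (m, r) pairs can bridge to the enum-taking overload with a small adapter. A hedged sketch in the spirit of the hunk above (the shim and its name fmrFromIntegers are illustrative, not part of the patch; it assumes <optional> plus the includes already in TestLinalgTransforms.cpp):

    #include <cstdint>
    #include <optional>
    // Hypothetical migration shim for legacy (m, r) call sites; only the
    // configurations exercised by the tests in this commit are mapped.
    static std::optional<mlir::linalg::WinogradConv2DFmr>
    fmrFromIntegers(int64_t m, int64_t r) {
      using mlir::linalg::WinogradConv2DFmr;
      if (m == 4 && r == 3)
        return WinogradConv2DFmr::F_4_3;
      if (m == 2 && r == 5)
        return WinogradConv2DFmr::F_2_5;
      return std::nullopt; // unsupported minimal filtering configuration
    }

    static void applyWinogradConv2D(func::FuncOp funcOp, int64_t m, int64_t r) {
      RewritePatternSet patterns(funcOp.getContext());
      if (std::optional<mlir::linalg::WinogradConv2DFmr> fmr =
              fmrFromIntegers(m, r))
        populateWinogradConv2DPatterns(patterns, *fmr);
      (void)applyPatternsGreedily(funcOp, std::move(patterns));
    }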