[Mlir-commits] [mlir] [mlir][linalg] Constrain the parameters m, r in Winograd ops (PR #144657)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 18 02:15:41 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Hsiangkai Wang (Hsiangkai)

<details>
<summary>Changes</summary>

We only support a fixed set of minimal filtering algorithms for Winograd Conv2D decomposition. Instead of letting users specify arbitrary integers for m and r, define a fixed set of enumeration values for the parameters of the minimal filtering algorithm.

---

Patch is 86.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144657.diff


16 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+7) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+18) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+6-12) 
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-2) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5-4) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+44-13) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+64-79) 
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+15-15) 
- (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+12-12) 
- (modified) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+12-12) 
- (modified) mlir/test/Dialect/Linalg/transform-tile-winograd.mlir (+18-18) 
- (modified) mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir (+6-6) 
- (modified) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+3-3) 
- (modified) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+21-21) 
- (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+3-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..69e09f6b32c2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -100,6 +100,13 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
 
 #include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
 
+namespace mlir {
+namespace linalg {
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
+} // namespace linalg
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Linalg Attributes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index ce68afe471fe8..8c98c0b8b8683 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -122,4 +122,22 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
 
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr",
+    "Winograd Conv2D F(m, r)",
+    [
+      I32EnumAttrCase<"F_2_3", 0>,
+      I32EnumAttrCase<"F_4_3", 1>,
+      I32EnumAttrCase<"F_2_5", 2>,
+      I32EnumAttrCase<"Unknown", -1>,
+    ]>{
+  let cppNamespace = "mlir::linalg";
+}
+
 #endif // LINALG_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1b48bf5fcb237..7ff44c2e1d2ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,15 +183,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
                        TensorRankOf<[AnyType], [4]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $filter `:` type($filter) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
@@ -254,15 +252,13 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
                        TensorRankOf<[AnyType], [6]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [6]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $input `:` type($input) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
@@ -343,15 +339,13 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
                        TensorRankOf<[AnyType], [4]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $value `:` type($value) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..1b035ec2a457e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -9,6 +9,7 @@
 #ifndef LINALG_TRANSFORM_OPS
 #define LINALG_TRANSFORM_OPS
 
+include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
 include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -2802,8 +2803,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$m,
-                       I64Attr:$r);
+                       WinogradConv2DFmr:$fmr);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
   let assemblyFormat =
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e5ee7724cd32d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -36,6 +36,7 @@ class BufferizationState;
 namespace linalg {
 
 class LinalgOp;
+enum class WinogradConv2DFmr : uint32_t;
 
 //===----------------------------------------------------------------------===//
 // Utils.
@@ -1337,8 +1338,8 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
 /// F(m x m, r x r). m is the dimension size of output and r is the dimension
 /// size of filter.
 FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r);
+                                      linalg::Conv2DNhwcFhwcOp op,
+                                      WinogradConv2DFmr fmr);
 
 /// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
 /// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
@@ -1879,8 +1880,8 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
 /// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r);
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
+                                    WinogradConv2DFmr fmr);
 
 /// Patterns to decompose Winograd operators.
 void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..ac592ad808311 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2989,8 +2989,9 @@ LogicalResult WinogradFilterTransformOp::verify() {
   ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterH = filterShape[getFilterHDim()];
   int64_t filterW = filterShape[getFilterWDim()];
-  int64_t r = getR();
-  int64_t m = getM();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   if (filterH != r && filterH != 1)
     return emitOpError("expect filter height either equals to r or 1");
@@ -3046,8 +3047,9 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
   ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterH = filterShape[getFilterHDim()];
   int64_t filterW = filterShape[getFilterWDim()];
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t alpha = m + r - 1;
   int64_t alphaH = filterH != 1 ? alpha : 1;
   int64_t alphaW = filterW != 1 ? alpha : 1;
@@ -3124,8 +3126,9 @@ LogicalResult WinogradInputTransformOp::verify() {
   ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[getInputHDim()];
   int64_t inputW = inputShape[getInputWDim()];
-  int m = getM();
-  int r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t tileSize = m + r - 1;
 
   auto outputType = cast<ShapedType>(getOutput().getType());
@@ -3194,8 +3197,9 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
   int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
   int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
 
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t alpha = m + r - 1;
   int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
   int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
@@ -3224,8 +3228,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   ShapedType outputType = getOutputOperandType();
   ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3303,8 +3308,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
   int64_t valueW = valueShape[getValueAlphaWDim()];
   int64_t valueTileH = valueShape[getValueTileHDim()];
   int64_t valueTileW = valueShape[getValueTileWDim()];
-  int m = getM();
-  int r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   bool leftTransform = valueH != 1;
   bool rightTransform = valueW != 1;
 
@@ -3365,7 +3371,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  int64_t m = getM();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
@@ -3623,6 +3631,29 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
 namespace mlir {
 namespace linalg {
 
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r) {
+  if (m == 2 && r == 3)
+    return WinogradConv2DFmr::F_2_3;
+  if (m == 4 && r == 3)
+    return WinogradConv2DFmr::F_4_3;
+  if (m == 2 && r == 5)
+    return WinogradConv2DFmr::F_2_5;
+  return WinogradConv2DFmr::Unknown;
+}
+
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
+  switch (fmr) {
+  case WinogradConv2DFmr::F_2_3:
+    return {2, 3};
+  case WinogradConv2DFmr::F_4_3:
+    return {4, 3};
+  case WinogradConv2DFmr::F_2_5:
+    return {2, 5};
+  default:
+    return {-1, -1};
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // MatMulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..528c561167824 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4030,7 +4030,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
   bool supported = TypeSwitch<Operation *, bool>(target)
                        .Case([&](linalg::Conv2DNhwcFhwcOp op) {
                          maybeTransformed =
-                             winogradConv2D(rewriter, op, getM(), getR());
+                             winogradConv2D(rewriter, op, getFmr());
                          return true;
                        })
                        .Default([&](Operation *op) { return false; });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index e4221d4748415..a4f835350fb52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -176,19 +177,6 @@ constexpr float A_2x2_5x5[] = {
 };
 // clang-format on
 
-using TransformMapKeyTy = std::pair<int, int>;
-
-/// We use F(m, r) to define the size of minimal filtering algorithms.
-/// m is the output dimension and r is the filter dimension. We can get
-/// the input dimension, alpha, from the formula, alpha = m + r - 1.
-///
-/// For example, when m = 2 and r = 3, we know its input size is 4.
-/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-/// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
   TransformMatrix(const float *table, int64_t rows, int64_t cols,
@@ -344,22 +332,22 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
 ///     %ret = linalg.matmul %ret, GT
 ///     %inserted = insert %ret into filter<h x w x c x f>
 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
-                      Value retValue, int64_t m, int64_t r,
+                      Value retValue, WinogradConv2DFmr fmr,
                       bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to G transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       GMatrices = {
-          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
-          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
-          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
       };
 
   // Map from (m, r) to GT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       GTMatrices = {
-          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
-          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
-          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
       };
 
   auto filterType = cast<ShapedType>(filter.getType());
@@ -370,6 +358,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
   int64_t filterW = filterShape[2];
   int64_t filterC = filterShape[3];
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   if (filterH != r && filterH != 1)
     return Value();
   if (filterW != r && filterW != 1)
@@ -387,14 +377,13 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                             zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
-    TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     Value matmulRetValue = extractFilter;
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix G.
-      auto it = GMatrices.find(key);
+      auto it = GMatrices.find(fmr);
       if (it == GMatrices.end())
         return {};
       const TransformMatrix &GMatrix = it->second;
@@ -416,7 +405,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 
     if (rightTransform) {
       // Get constant transform matrix GT.
-      auto it = GTMatrices.find(key);
+      auto it = GTMatrices.find(fmr);
       if (it == GTMatrices.end())
         return {};
       const TransformMatrix &GTMatrix = it->second;
@@ -476,24 +465,26 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 ///                            %output<alphaH x alphaW x tileH x tileW x N x C>
 ///                            at [0, 0, %h, %w, %n, %c]
 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
-                     Value retValue, int64_t m, int64_t r,
+                     Value retValue, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to BT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       BTMatrices = {
-          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
       };
 
   // Map from (m, r) to B transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       BMatrices = {
-          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
       };
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto inputType = cast<ShapedType>(input.getType());
   Type elementType = inputType.getElementType();
   auto inputShape = inputType.getShape(); // N, H, W, C
@@ -529,7 +520,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                             widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
-    TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     int64_t retCols = 1;
     Value matmulRetValue = extractInput;
@@ -537,7 +527,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
         loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix BT.
-      auto it = BTMatrices.find(key);
+      auto it = BTMatrices.find(fmr);
       if (it == BTMatrices.end())
         return {};
       const TransformMatrix &BTMatrix = it->second;
@@ -560,7 +550,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
 
     if (rightTransform) {
       // Get constant transform matrix B.
-      auto it = BMatrices.find(key);
+      auto it = BMatrices.find(fmr);
       if (it == BMatrices.end())
         return {};
       const TransformMatrix &BMatrix = it->second;
@@ -696,24 +686,26 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
 ///                            output<N x H x W x F>
 ///                            at [%n, (%h x m), (%w x m), %f]
 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
-                      Value output, int64_t m, int64_t r,
+                      Value output, WinogradConv2DFmr fmr,
                       bool l...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/144657


More information about the Mlir-commits mailing list