[Mlir-commits] [mlir] [mlir][linalg] Constrain the parameters m, r in Winograd ops (PR #144657)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 18 02:15:41 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Hsiangkai Wang (Hsiangkai)
<details>
<summary>Changes</summary>
We only support a fixed set of minimal filtering algorithms for Winograd Conv2D decomposition. Instead of letting users specify arbitrary integers for m and r, define a fixed set of enumeration values for the parameters of the minimal filtering algorithm.
---
Patch is 86.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144657.diff
16 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+7)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+18)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+6-12)
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-2)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5-4)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+44-13)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+64-79)
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+15-15)
- (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+12-12)
- (modified) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+12-12)
- (modified) mlir/test/Dialect/Linalg/transform-tile-winograd.mlir (+18-18)
- (modified) mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir (+6-6)
- (modified) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+3-3)
- (modified) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+21-21)
- (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+3-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..69e09f6b32c2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -100,6 +100,13 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
+namespace mlir {
+namespace linalg {
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
+} // namespace linalg
+} // namespace mlir
+
//===----------------------------------------------------------------------===//
// Linalg Attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index ce68afe471fe8..8c98c0b8b8683 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -122,4 +122,22 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
let cppNamespace = "::mlir::linalg";
}
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr",
+ "Winograd Conv2D F(m, r)",
+ [
+ I32EnumAttrCase<"F_2_3", 0>,
+ I32EnumAttrCase<"F_4_3", 1>,
+ I32EnumAttrCase<"F_2_5", 2>,
+ I32EnumAttrCase<"Unknown", -1>,
+ ]>{
+ let cppNamespace = "mlir::linalg";
+}
+
#endif // LINALG_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1b48bf5fcb237..7ff44c2e1d2ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,15 +183,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
TensorRankOf<[AnyType], [4]>:$output,
- I64Attr:$m,
- I64Attr:$r
+ WinogradConv2DFmr:$fmr
);
let results = (outs TensorRankOf<[AnyType], [4]>:$result);
let assemblyFormat = [{
attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
+ `fmr` `(` $fmr `)`
`ins` `(` $filter `:` type($filter) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
@@ -254,15 +252,13 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
TensorRankOf<[AnyType], [6]>:$output,
- I64Attr:$m,
- I64Attr:$r
+ WinogradConv2DFmr:$fmr
);
let results = (outs TensorRankOf<[AnyType], [6]>:$result);
let assemblyFormat = [{
attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
+ `fmr` `(` $fmr `)`
`ins` `(` $input `:` type($input) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
@@ -343,15 +339,13 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
TensorRankOf<[AnyType], [4]>:$output,
- I64Attr:$m,
- I64Attr:$r
+ WinogradConv2DFmr:$fmr
);
let results = (outs TensorRankOf<[AnyType], [4]>:$result);
let assemblyFormat = [{
attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
+ `fmr` `(` $fmr `)`
`ins` `(` $value `:` type($value) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..1b035ec2a457e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -9,6 +9,7 @@
#ifndef LINALG_TRANSFORM_OPS
#define LINALG_TRANSFORM_OPS
+include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -2802,8 +2803,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- I64Attr:$m,
- I64Attr:$r);
+ WinogradConv2DFmr:$fmr);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e5ee7724cd32d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -36,6 +36,7 @@ class BufferizationState;
namespace linalg {
class LinalgOp;
+enum class WinogradConv2DFmr : uint32_t;
//===----------------------------------------------------------------------===//
// Utils.
@@ -1337,8 +1338,8 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
/// F(m x m, r x r). m is the dimension size of output and r is the dimension
/// size of filter.
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
- linalg::Conv2DNhwcFhwcOp op, int64_t m,
- int64_t r);
+ linalg::Conv2DNhwcFhwcOp op,
+ WinogradConv2DFmr fmr);
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
@@ -1879,8 +1880,8 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
- int64_t r);
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
+ WinogradConv2DFmr fmr);
/// Patterns to decompose Winograd operators.
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..ac592ad808311 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2989,8 +2989,9 @@ LogicalResult WinogradFilterTransformOp::verify() {
ArrayRef<int64_t> filterShape = filterType.getShape();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
- int64_t r = getR();
- int64_t m = getM();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
if (filterH != r && filterH != 1)
return emitOpError("expect filter height either equals to r or 1");
@@ -3046,8 +3047,9 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
ArrayRef<int64_t> filterShape = filterType.getShape();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
- int64_t m = getM();
- int64_t r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t alpha = m + r - 1;
int64_t alphaH = filterH != 1 ? alpha : 1;
int64_t alphaW = filterW != 1 ? alpha : 1;
@@ -3124,8 +3126,9 @@ LogicalResult WinogradInputTransformOp::verify() {
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputH = inputShape[getInputHDim()];
int64_t inputW = inputShape[getInputWDim()];
- int m = getM();
- int r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t tileSize = m + r - 1;
auto outputType = cast<ShapedType>(getOutput().getType());
@@ -3194,8 +3197,9 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
- int64_t m = getM();
- int64_t r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t alpha = m + r - 1;
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
@@ -3224,8 +3228,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
- int64_t m = getM();
- int64_t r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
ShapedType outputType = getOutputOperandType();
ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3303,8 +3308,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
int64_t valueW = valueShape[getValueAlphaWDim()];
int64_t valueTileH = valueShape[getValueTileHDim()];
int64_t valueTileW = valueShape[getValueTileWDim()];
- int m = getM();
- int r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
bool leftTransform = valueH != 1;
bool rightTransform = valueW != 1;
@@ -3365,7 +3371,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
- int64_t m = getM();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
Location loc = getLoc();
MLIRContext *context = builder.getContext();
@@ -3623,6 +3631,29 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
namespace mlir {
namespace linalg {
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r) {
+ if (m == 2 && r == 3)
+ return WinogradConv2DFmr::F_2_3;
+ if (m == 4 && r == 3)
+ return WinogradConv2DFmr::F_4_3;
+ if (m == 2 && r == 5)
+ return WinogradConv2DFmr::F_2_5;
+ return WinogradConv2DFmr::Unknown;
+}
+
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
+ switch (fmr) {
+ case WinogradConv2DFmr::F_2_3:
+ return {2, 3};
+ case WinogradConv2DFmr::F_4_3:
+ return {4, 3};
+ case WinogradConv2DFmr::F_2_5:
+ return {2, 5};
+ default:
+ return {-1, -1};
+ }
+}
+
//===----------------------------------------------------------------------===//
// MatMulOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..528c561167824 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4030,7 +4030,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
bool supported = TypeSwitch<Operation *, bool>(target)
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
maybeTransformed =
- winogradConv2D(rewriter, op, getM(), getR());
+ winogradConv2D(rewriter, op, getFmr());
return true;
})
.Default([&](Operation *op) { return false; });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index e4221d4748415..a4f835350fb52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -176,19 +177,6 @@ constexpr float A_2x2_5x5[] = {
};
// clang-format on
-using TransformMapKeyTy = std::pair<int, int>;
-
-/// We use F(m, r) to define the size of minimal filtering algorithms.
-/// m is the output dimension and r is the filter dimension. We can get
-/// the input dimension, alpha, from the formula, alpha = m + r - 1.
-///
-/// For example, when m = 2 and r = 3, we know its input size is 4.
-/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-/// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
TransformMatrix(const float *table, int64_t rows, int64_t cols,
@@ -344,22 +332,22 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
/// %ret = linalg.matmul %ret, GT
/// %inserted = insert %ret into filter<h x w x c x f>
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
- Value retValue, int64_t m, int64_t r,
+ Value retValue, WinogradConv2DFmr fmr,
bool leftTransform = true, bool rightTransform = true) {
// Map from (m, r) to G transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
GMatrices = {
- {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
- {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
- {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
};
// Map from (m, r) to GT transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
GTMatrices = {
- {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
- {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
- {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
};
auto filterType = cast<ShapedType>(filter.getType());
@@ -370,6 +358,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
int64_t filterW = filterShape[2];
int64_t filterC = filterShape[3];
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
if (filterH != r && filterH != 1)
return Value();
if (filterW != r && filterW != 1)
@@ -387,14 +377,13 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
/*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
- TransformMapKeyTy key = {m, r};
int64_t retRows = 1;
Value matmulRetValue = extractFilter;
Value zero = builder.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix G.
- auto it = GMatrices.find(key);
+ auto it = GMatrices.find(fmr);
if (it == GMatrices.end())
return {};
const TransformMatrix &GMatrix = it->second;
@@ -416,7 +405,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
if (rightTransform) {
// Get constant transform matrix GT.
- auto it = GTMatrices.find(key);
+ auto it = GTMatrices.find(fmr);
if (it == GTMatrices.end())
return {};
const TransformMatrix >Matrix = it->second;
@@ -476,24 +465,26 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
/// %output<alphaH x alphaW x tileH x tileW x N x C>
/// at [0, 0, %h, %w, %n, %c]
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
- Value retValue, int64_t m, int64_t r,
+ Value retValue, WinogradConv2DFmr fmr,
bool leftTransform = true, bool rightTransform = true) {
// Map from (m, r) to BT transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
BTMatrices = {
- {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
- {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
- {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
};
// Map from (m, r) to B transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
BMatrices = {
- {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
- {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
- {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
};
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
auto inputType = cast<ShapedType>(input.getType());
Type elementType = inputType.getElementType();
auto inputShape = inputType.getShape(); // N, H, W, C
@@ -529,7 +520,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
/*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
- TransformMapKeyTy key = {m, r};
int64_t retRows = 1;
int64_t retCols = 1;
Value matmulRetValue = extractInput;
@@ -537,7 +527,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix BT.
- auto it = BTMatrices.find(key);
+ auto it = BTMatrices.find(fmr);
if (it == BTMatrices.end())
return {};
const TransformMatrix &BTMatrix = it->second;
@@ -560,7 +550,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
if (rightTransform) {
// Get constant transform matrix B.
- auto it = BMatrices.find(key);
+ auto it = BMatrices.find(fmr);
if (it == BMatrices.end())
return {};
const TransformMatrix &BMatrix = it->second;
@@ -696,24 +686,26 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
/// output<N x H x W x F>
/// at [%n, (%h x m), (%w x m), %f]
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
- Value output, int64_t m, int64_t r,
+ Value output, WinogradConv2DFmr fmr,
bool l...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/144657
More information about the Mlir-commits
mailing list