[Mlir-commits] [mlir] [mlir][linalg] Constrain the parameters m, r in Winograd ops (PR #144657)
Hsiangkai Wang
llvmlistbot at llvm.org
Wed Jun 25 02:32:30 PDT 2025
https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/144657
>From 50e729fb94901f2c415f383f6b25d216d737f986 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Tue, 17 Jun 2025 23:33:17 +0100
Subject: [PATCH 1/4] [mlir][linalg] Constrain the parameters m, r in Winograd
ops
We only support a fixed set of minimal filtering algorithms for Winograd
Conv2D decomposition. Instead of letting users specify arbitrary integers,
define a fixed set of enumeration values for the parameters of the minimal
filtering algorithm.
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 7 +
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 18 +++
.../mlir/Dialect/Linalg/IR/LinalgOps.td | 18 +--
.../Linalg/TransformOps/LinalgTransformOps.td | 4 +-
.../Dialect/Linalg/Transforms/Transforms.h | 9 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 57 +++++--
.../TransformOps/LinalgTransformOps.cpp | 2 +-
.../Linalg/Transforms/WinogradConv2D.cpp | 143 ++++++++----------
mlir/test/Dialect/Linalg/invalid.mlir | 30 ++--
mlir/test/Dialect/Linalg/roundtrip.mlir | 24 +--
.../transform-tile-and-winograd-rewrite.mlir | 24 +--
.../Linalg/transform-tile-winograd.mlir | 36 ++---
.../Linalg/transform-winograd-conv2d.mlir | 12 +-
.../Linalg/winograd-conv2d-rewrite.mlir | 6 +-
mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 42 ++---
.../Dialect/Linalg/TestLinalgTransforms.cpp | 5 +-
16 files changed, 237 insertions(+), 200 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..69e09f6b32c2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -100,6 +100,13 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
+namespace mlir {
+namespace linalg {
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
+} // namespace linalg
+} // namespace mlir
+
//===----------------------------------------------------------------------===//
// Linalg Attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index ce68afe471fe8..8c98c0b8b8683 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -122,4 +122,22 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
let cppNamespace = "::mlir::linalg";
}
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr",
+ "Winograd Conv2D F(m, r)",
+ [
+ I32EnumAttrCase<"F_2_3", 0>,
+ I32EnumAttrCase<"F_4_3", 1>,
+ I32EnumAttrCase<"F_2_5", 2>,
+ I32EnumAttrCase<"Unknown", -1>,
+ ]>{
+ let cppNamespace = "mlir::linalg";
+}
+
#endif // LINALG_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1b48bf5fcb237..7ff44c2e1d2ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,15 +183,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
TensorRankOf<[AnyType], [4]>:$output,
- I64Attr:$m,
- I64Attr:$r
+ WinogradConv2DFmr:$fmr
);
let results = (outs TensorRankOf<[AnyType], [4]>:$result);
let assemblyFormat = [{
attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
+ `fmr` `(` $fmr `)`
`ins` `(` $filter `:` type($filter) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
@@ -254,15 +252,13 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
TensorRankOf<[AnyType], [6]>:$output,
- I64Attr:$m,
- I64Attr:$r
+ WinogradConv2DFmr:$fmr
);
let results = (outs TensorRankOf<[AnyType], [6]>:$result);
let assemblyFormat = [{
attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
+ `fmr` `(` $fmr `)`
`ins` `(` $input `:` type($input) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
@@ -343,15 +339,13 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
TensorRankOf<[AnyType], [4]>:$output,
- I64Attr:$m,
- I64Attr:$r
+ WinogradConv2DFmr:$fmr
);
let results = (outs TensorRankOf<[AnyType], [4]>:$result);
let assemblyFormat = [{
attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
+ `fmr` `(` $fmr `)`
`ins` `(` $value `:` type($value) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..1b035ec2a457e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -9,6 +9,7 @@
#ifndef LINALG_TRANSFORM_OPS
#define LINALG_TRANSFORM_OPS
+include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -2802,8 +2803,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- I64Attr:$m,
- I64Attr:$r);
+ WinogradConv2DFmr:$fmr);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e5ee7724cd32d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -36,6 +36,7 @@ class BufferizationState;
namespace linalg {
class LinalgOp;
+enum class WinogradConv2DFmr : uint32_t;
//===----------------------------------------------------------------------===//
// Utils.
@@ -1337,8 +1338,8 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
/// F(m x m, r x r). m is the dimension size of output and r is the dimension
/// size of filter.
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
- linalg::Conv2DNhwcFhwcOp op, int64_t m,
- int64_t r);
+ linalg::Conv2DNhwcFhwcOp op,
+ WinogradConv2DFmr fmr);
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
@@ -1879,8 +1880,8 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
- int64_t r);
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
+ WinogradConv2DFmr fmr);
/// Patterns to decompose Winograd operators.
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..ac592ad808311 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2989,8 +2989,9 @@ LogicalResult WinogradFilterTransformOp::verify() {
ArrayRef<int64_t> filterShape = filterType.getShape();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
- int64_t r = getR();
- int64_t m = getM();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
if (filterH != r && filterH != 1)
return emitOpError("expect filter height either equals to r or 1");
@@ -3046,8 +3047,9 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
ArrayRef<int64_t> filterShape = filterType.getShape();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
- int64_t m = getM();
- int64_t r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t alpha = m + r - 1;
int64_t alphaH = filterH != 1 ? alpha : 1;
int64_t alphaW = filterW != 1 ? alpha : 1;
@@ -3124,8 +3126,9 @@ LogicalResult WinogradInputTransformOp::verify() {
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputH = inputShape[getInputHDim()];
int64_t inputW = inputShape[getInputWDim()];
- int m = getM();
- int r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t tileSize = m + r - 1;
auto outputType = cast<ShapedType>(getOutput().getType());
@@ -3194,8 +3197,9 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
- int64_t m = getM();
- int64_t r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t alpha = m + r - 1;
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
@@ -3224,8 +3228,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
- int64_t m = getM();
- int64_t r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
ShapedType outputType = getOutputOperandType();
ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3303,8 +3308,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
int64_t valueW = valueShape[getValueAlphaWDim()];
int64_t valueTileH = valueShape[getValueTileHDim()];
int64_t valueTileW = valueShape[getValueTileWDim()];
- int m = getM();
- int r = getR();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
bool leftTransform = valueH != 1;
bool rightTransform = valueW != 1;
@@ -3365,7 +3371,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
- int64_t m = getM();
+ WinogradConv2DFmr fmr = getFmr();
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
Location loc = getLoc();
MLIRContext *context = builder.getContext();
@@ -3623,6 +3631,29 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
namespace mlir {
namespace linalg {
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r) {
+ if (m == 2 && r == 3)
+ return WinogradConv2DFmr::F_2_3;
+ if (m == 4 && r == 3)
+ return WinogradConv2DFmr::F_4_3;
+ if (m == 2 && r == 5)
+ return WinogradConv2DFmr::F_2_5;
+ return WinogradConv2DFmr::Unknown;
+}
+
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
+ switch (fmr) {
+ case WinogradConv2DFmr::F_2_3:
+ return {2, 3};
+ case WinogradConv2DFmr::F_4_3:
+ return {4, 3};
+ case WinogradConv2DFmr::F_2_5:
+ return {2, 5};
+ default:
+ return {-1, -1};
+ }
+}
+
//===----------------------------------------------------------------------===//
// MatMulOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..528c561167824 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4030,7 +4030,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
bool supported = TypeSwitch<Operation *, bool>(target)
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
maybeTransformed =
- winogradConv2D(rewriter, op, getM(), getR());
+ winogradConv2D(rewriter, op, getFmr());
return true;
})
.Default([&](Operation *op) { return false; });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index e4221d4748415..a4f835350fb52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -176,19 +177,6 @@ constexpr float A_2x2_5x5[] = {
};
// clang-format on
-using TransformMapKeyTy = std::pair<int, int>;
-
-/// We use F(m, r) to define the size of minimal filtering algorithms.
-/// m is the output dimension and r is the filter dimension. We can get
-/// the input dimension, alpha, from the formula, alpha = m + r - 1.
-///
-/// For example, when m = 2 and r = 3, we know its input size is 4.
-/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-/// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
TransformMatrix(const float *table, int64_t rows, int64_t cols,
@@ -344,22 +332,22 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
/// %ret = linalg.matmul %ret, GT
/// %inserted = insert %ret into filter<h x w x c x f>
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
- Value retValue, int64_t m, int64_t r,
+ Value retValue, WinogradConv2DFmr fmr,
bool leftTransform = true, bool rightTransform = true) {
// Map from (m, r) to G transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
GMatrices = {
- {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
- {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
- {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
};
// Map from (m, r) to GT transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
GTMatrices = {
- {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
- {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
- {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
};
auto filterType = cast<ShapedType>(filter.getType());
@@ -370,6 +358,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
int64_t filterW = filterShape[2];
int64_t filterC = filterShape[3];
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
if (filterH != r && filterH != 1)
return Value();
if (filterW != r && filterW != 1)
@@ -387,14 +377,13 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
/*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
- TransformMapKeyTy key = {m, r};
int64_t retRows = 1;
Value matmulRetValue = extractFilter;
Value zero = builder.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix G.
- auto it = GMatrices.find(key);
+ auto it = GMatrices.find(fmr);
if (it == GMatrices.end())
return {};
const TransformMatrix &GMatrix = it->second;
@@ -416,7 +405,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
if (rightTransform) {
// Get constant transform matrix GT.
- auto it = GTMatrices.find(key);
+ auto it = GTMatrices.find(fmr);
if (it == GTMatrices.end())
return {};
const TransformMatrix &GTMatrix = it->second;
@@ -476,24 +465,26 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
/// %output<alphaH x alphaW x tileH x tileW x N x C>
/// at [0, 0, %h, %w, %n, %c]
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
- Value retValue, int64_t m, int64_t r,
+ Value retValue, WinogradConv2DFmr fmr,
bool leftTransform = true, bool rightTransform = true) {
// Map from (m, r) to BT transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
BTMatrices = {
- {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
- {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
- {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
};
// Map from (m, r) to B transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
BMatrices = {
- {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
- {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
- {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
};
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
auto inputType = cast<ShapedType>(input.getType());
Type elementType = inputType.getElementType();
auto inputShape = inputType.getShape(); // N, H, W, C
@@ -529,7 +520,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
/*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
- TransformMapKeyTy key = {m, r};
int64_t retRows = 1;
int64_t retCols = 1;
Value matmulRetValue = extractInput;
@@ -537,7 +527,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix BT.
- auto it = BTMatrices.find(key);
+ auto it = BTMatrices.find(fmr);
if (it == BTMatrices.end())
return {};
const TransformMatrix &BTMatrix = it->second;
@@ -560,7 +550,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
if (rightTransform) {
// Get constant transform matrix B.
- auto it = BMatrices.find(key);
+ auto it = BMatrices.find(fmr);
if (it == BMatrices.end())
return {};
const TransformMatrix &BMatrix = it->second;
@@ -696,24 +686,26 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
/// output<N x H x W x F>
/// at [%n, (%h x m), (%w x m), %f]
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
- Value output, int64_t m, int64_t r,
+ Value output, WinogradConv2DFmr fmr,
bool leftTransform = true, bool rightTransform = true) {
// Map from (m, r) to AT transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
ATMatrices = {
- {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
- {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
- {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
};
// Map from (m, r) to A transform matrix.
- static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
AMatrices = {
- {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
- {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
- {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+ {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+ {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+ {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
};
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
auto valueType = cast<ShapedType>(value.getType());
Type elementType = valueType.getElementType();
auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
@@ -743,9 +735,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
FIter, 2, 3, /*loopNorFIdx=*/4,
/*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
- const TransformMapKeyTy key = {m, r};
- const TransformMatrix &AMatrix = AMatrices.at(key);
- const TransformMatrix &ATMatrix = ATMatrices.at(key);
+ const TransformMatrix &AMatrix = AMatrices.at(fmr);
+ const TransformMatrix &ATMatrix = ATMatrices.at(fmr);
int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
(leftTransform ? ATMatrix.scalarFactor : 1);
int64_t retCols = rightTransform ? AMatrix.cols : 1;
@@ -903,7 +894,7 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
/// linalg.winograd_*_transform ops.
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
- int64_t m, int64_t r) {
+ WinogradConv2DFmr fmr) {
if (!convOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");
@@ -946,6 +937,8 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
int64_t outputW = outputShape[2];
int64_t outputF = outputShape[3];
+ int64_t m, r;
+ std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
// Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
bool isSupportedFilter = false;
if (filterH == filterW && filterH == r)
@@ -959,15 +952,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
return rewriter.notifyMatchFailure(
convOp, "only support filter (r x r), (r x 1) or (1 x r)");
- // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
- static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
- F_2_3, F_4_3, F_2_5};
-
- TransformMapKeyTy key = {m, r};
- auto it = llvm::find(validConfigs, key);
- // If we cannot find the constant transformation matrix, it means we do
- // not support this configuration yet.
- if (it == validConfigs.end())
+ if (fmr == WinogradConv2DFmr::Unknown)
return failure();
// All the criterias are satisfied. We can do Winograd Conv2D.
@@ -993,7 +978,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
filterElementType);
auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
- loc, retType, filter, retValue, m, r);
+ loc, retType, filter, retValue, fmr);
// --- Create operation for input transform ---
@@ -1012,7 +997,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
inputElementType);
auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
- loc, retType, input, retValue, m, r);
+ loc, retType, input, retValue, fmr);
Type outputElementType = outputType.getElementType();
Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
@@ -1035,7 +1020,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
}
Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
- loc, outputType, matmulRet, output, m, r);
+ loc, outputType, matmulRet, output, fmr);
// When output size is not aligned with output tile size, extract the
// value from the padded buffer.
@@ -1067,8 +1052,8 @@ decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
// For F(1 x m, 1 x r), we only need to do right side transform.
bool rightTransform = filterW != 1;
Value transformedFilter =
- filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
- op.getR(), leftTransform, rightTransform);
+ filterTransform(rewriter, loc, filter, op.getOutput(), op.getFmr(),
+ leftTransform, rightTransform);
if (!transformedFilter)
return failure();
@@ -1094,8 +1079,8 @@ decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
// For F(1 x m, 1 x r), we only need to do right side transform.
bool rightTransform = outputW != 1;
Value transformedInput =
- inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
- op.getR(), leftTransform, rightTransform);
+ inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getFmr(),
+ leftTransform, rightTransform);
if (!transformedInput)
return failure();
@@ -1120,8 +1105,8 @@ decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
// For F(1 x m, 1 x r), we only need to do right side transform.
bool rightTransform = valueW != 1;
Value transformedOutput =
- outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
- op.getR(), leftTransform, rightTransform);
+ outputTransform(rewriter, loc, value, op.getOutput(), op.getFmr(),
+ leftTransform, rightTransform);
if (!transformedOutput)
return failure();
@@ -1171,28 +1156,28 @@ class WinogradConv2DNhwcFhwc final
: public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
using OpRewritePattern::OpRewritePattern;
- WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
- : OpRewritePattern(context), m(m), r(r) {}
+ WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, WinogradConv2DFmr fmr)
+ : OpRewritePattern(context), fmr(fmr) {}
LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
PatternRewriter &rewriter) const override {
- if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+ if (failed(winogradConv2DHelper(rewriter, convOp, fmr)))
return failure();
return success();
}
private:
- int64_t m;
- int64_t r;
+ WinogradConv2DFmr fmr;
};
+
} // end anonymous namespace
//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
- linalg::Conv2DNhwcFhwcOp op, int64_t m,
- int64_t r) {
- return winogradConv2DHelper(rewriter, op, m, r);
+ linalg::Conv2DNhwcFhwcOp op,
+ linalg::WinogradConv2DFmr fmr) {
+ return winogradConv2DHelper(rewriter, op, fmr);
}
FailureOr<Operation *>
@@ -1213,11 +1198,11 @@ decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
return decomposeWinogradOutputTransformHelper(rewriter, op);
}
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
- int64_t r) {
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
+ WinogradConv2DFmr fmr) {
MLIRContext *context = patterns.getContext();
// TODO: Support more Conv2D data layout, e.g., conv_2d_nchw_fchw
- patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
+ patterns.insert<WinogradConv2DNhwcFhwc>(context, fmr);
}
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c0c5f785e856b..f3a59091fe437 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1165,7 +1165,7 @@ func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<
func.func @winograd_filter_transform_height(%arg0: tensor<2x4x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
// expected-error @+1 {{expect filter height either equals to r or 1}}
- %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x4x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x4x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
return %0 : tensor<6x6x5x2xf32>
}
@@ -1173,7 +1173,7 @@ func.func @winograd_filter_transform_height(%arg0: tensor<2x4x3x5xf32>, %arg1: t
func.func @winograd_filter_transform_width(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
// expected-error @+1 {{expect filter width either equals to r or 1}}
- %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x4x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x4x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
return %0 : tensor<6x6x5x2xf32>
}
@@ -1181,7 +1181,7 @@ func.func @winograd_filter_transform_width(%arg0: tensor<2x3x4x5xf32>, %arg1: te
func.func @winograd_filter_transform(%arg0: tensor<2x1x1x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
// expected-error @+1 {{expect either filter height or width equals to r}}
- %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x1x1x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x1x1x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
return %0 : tensor<6x6x5x2xf32>
}
@@ -1189,7 +1189,7 @@ func.func @winograd_filter_transform(%arg0: tensor<2x1x1x5xf32>, %arg1: tensor<6
func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32>
return %0 : tensor<6x5x?x?xf32>
}
@@ -1197,7 +1197,7 @@ func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x5x?x?
func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x13x14x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x13x14x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
return %0 : tensor<6x6x3x3x2x5xf32>
}
@@ -1205,7 +1205,7 @@ func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1:
func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x13x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x13x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
return %0 : tensor<6x6x3x3x2x5xf32>
}
@@ -1213,7 +1213,7 @@ func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: t
func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32>
return %0 : tensor<6x6x2x3x2x5xf32>
}
@@ -1221,7 +1221,7 @@ func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %
func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32>
return %0 : tensor<6x6x3x2x2x5xf32>
}
@@ -1229,7 +1229,7 @@ func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %
func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32>
return %0 : tensor<5x6x3x3x2x5xf32>
}
@@ -1237,7 +1237,7 @@ func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>,
func.func @winograd_input_transform_output_width(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32>
return %0 : tensor<6x5x3x3x2x5xf32>
}
@@ -1245,7 +1245,7 @@ func.func @winograd_input_transform_output_width(%arg0: tensor<2x14x14x5xf32>, %
func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32>
return %0 : tensor<6x5x?x?x?x?xf32>
}
@@ -1253,7 +1253,7 @@ func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x5x?x?x
func.func @winograd_output_transform_input_height(%arg0: tensor<5x6x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
// expected-error @+1 {{expect input height equals to input tile size}}
- %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<5x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<5x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
return %0 : tensor<2x12x12x2xf32>
}
@@ -1261,7 +1261,7 @@ func.func @winograd_output_transform_input_height(%arg0: tensor<5x6x3x3x2x2xf32>
func.func @winograd_output_transform_input_width(%arg0: tensor<6x5x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
// expected-error @+1 {{expect input width equals to input tile size}}
- %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x5x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x5x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
return %0 : tensor<2x12x12x2xf32>
}
@@ -1269,7 +1269,7 @@ func.func @winograd_output_transform_input_width(%arg0: tensor<6x5x3x3x2x2xf32>,
func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32>
return %0 : tensor<2x11x12x2xf32>
}
@@ -1277,7 +1277,7 @@ func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32
func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32> {
// expected-error @+1 {{the output shape is not expected}}
- %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
return %0 : tensor<2x12x11x2xf32>
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index dc556761b09e5..4edbc6eda3eae 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -630,52 +630,52 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
func.func @winograd(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
%0 = tensor.empty() : tensor<6x6x5x2xf32>
- %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
%2 = tensor.empty() : tensor<6x6x1x1x2x5xf32>
- %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%2 : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%2 : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
%4 = tensor.empty() : tensor<36x2x2xf32>
%5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
%expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
- %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+ %6 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
return %6 : tensor<2x4x4x2xf32>
}
// CHECK-LABEL: func @winograd
-// CHECK: linalg.winograd_filter_transform m(4) r(3)
-// CHECK: linalg.winograd_input_transform m(4) r(3)
-// CHECK: linalg.winograd_output_transform m(4) r(3)
+// CHECK: linalg.winograd_filter_transform fmr(F_4_3)
+// CHECK: linalg.winograd_input_transform fmr(F_4_3)
+// CHECK: linalg.winograd_output_transform fmr(F_4_3)
// -----
func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32> {
- %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
return %0 : tensor<6x6x?x?xf32>
}
// CHECK-LABEL: func @winograd_filter_dyn
-// CHECK: linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
+// CHECK: linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
// -----
func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32> {
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
return %0 : tensor<6x6x?x?x?x?xf32>
}
// CHECK-LABEL: func @winograd_input_dyn
-// CHECK: linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
+// CHECK: linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
// -----
func.func @winograd_output_dyn(%arg0: tensor<6x6x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
- %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: func @winograd_output_dyn
-// CHECK: linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// -----
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index cdc4b8a72a276..fe5124eb251d9 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -2,15 +2,15 @@
func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
%0 = tensor.empty() : tensor<6x6x5x2xf32>
- %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
%2 = tensor.empty() : tensor<6x6x2x2x2x5xf32>
- %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
%4 = tensor.empty() : tensor<36x8x2xf32>
%5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
%expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
- %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ %6 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
return %6 : tensor<2x8x8x2xf32>
}
@@ -123,13 +123,13 @@ module attributes {transform.with_named_sequence} {
func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<6x6x5x2xf32>
- %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
%padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
tensor.yield %cst : f32
} : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
%2 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
- %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
%4 = tensor.empty() : tensor<36x18x2xf32>
@@ -140,7 +140,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
tensor.yield %cst : f32
} : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
- %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
%extracted_slice = tensor.extract_slice %7[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
return %extracted_slice : tensor<2x9x9x2xf32>
}
@@ -259,16 +259,16 @@ module attributes {transform.with_named_sequence} {
func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<6x1x5x2xf32>
- %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
%2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
- %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
%4 = tensor.empty() : tensor<6x2x2xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
%6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
%expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
- %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+ %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
return %7 : tensor<2x4x1x2xf32>
}
@@ -350,16 +350,16 @@ module attributes {transform.with_named_sequence} {
func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<6x1x5x2xf32>
- %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
%2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
- %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
%4 = tensor.empty() : tensor<6x4x2xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
%6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
%expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
- %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
+ %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
return %7 : tensor<2x4x2x2xf32>
}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
index fc6424fd4c812..688264cc9a7aa 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s
func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
- %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
return %0 : tensor<6x6x5x2xf32>
}
@@ -25,13 +25,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
// CHECK: %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
-// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S3]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
// -----
func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
- %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
return %0 : tensor<6x6x5x2xf32>
}
@@ -58,12 +58,12 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
// CHECK: %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
-// CHECK: %[[S4:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
+// CHECK: %[[S4:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S4]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x?x1xf32> into tensor<6x6x5x2xf32>
// -----
func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
- %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
return %0 : tensor<6x1x5x2xf32>
}
@@ -87,13 +87,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
// CHECK: %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
-// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S3]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
// -----
func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
return %0 : tensor<6x6x2x2x2x5xf32>
}
@@ -123,13 +123,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]]()
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[S7:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[S7:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S7]] into %[[ARG5]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
// -----
func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
return %0 : tensor<6x6x2x2x2x5xf32>
}
@@ -167,13 +167,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]()
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x2x2x2x5xf32>
// -----
func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
return %0 : tensor<6x6x2x2x2x5xf32>
}
@@ -213,13 +213,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S9:.*]] = affine.apply #[[$MAP2]]()
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
-// CHECK: %[[S10:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+// CHECK: %[[S10:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x?xf32> into tensor<6x6x2x2x2x5xf32>
// -----
func.func @tile_winograd_input(%arg0: tensor<2x1x10x5xf32>, %arg1: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
- %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
return %0 : tensor<1x6x1x2x2x5xf32>
}
@@ -258,13 +258,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]]()
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
-// CHECK: %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x1x1x1xf32> into tensor<1x6x1x2x2x5xf32>
// -----
func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x2x2xf32>, %arg1: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
- %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
return %0 : tensor<2x8x8x2xf32>
}
@@ -298,7 +298,7 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x3x5xf32>, %arg1: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
- %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
return %0 : tensor<3x8x8x5xf32>
}
@@ -346,7 +346,7 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @tile_winograd_output(%arg0 : tensor<6x1x2x1x3x5xf32>, %arg1: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
- %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
return %0 : tensor<3x8x1x5xf32>
}
@@ -385,4 +385,4 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S7:.*]] = affine.apply #[[$MAP2]]()
// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]]()
// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
-// CHECK: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index 1de861e653005..f2b77dde0fe6a 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -14,10 +14,10 @@ module attributes {transform.with_named_sequence} {
}
// CHECK-LABEL: func.func @conv2d
-// CHECK: linalg.winograd_filter_transform m(4) r(3)
-// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: linalg.winograd_filter_transform mfr(F_4_3)
+// CHECK: linalg.winograd_input_transform mfr(F_4_3)
// CHECK: linalg.batch_matmul
-// CHECK: linalg.winograd_output_transform m(4) r(3)
+// CHECK: linalg.winograd_output_transform mfr(F_4_3)
// -----
@@ -35,13 +35,13 @@ module attributes {transform.with_named_sequence} {
}
// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK: linalg.winograd_filter_transform m(4) r(3)
+// CHECK: linalg.winograd_filter_transform mfr(F_4_3)
// CHECK: tensor.pad
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: linalg.winograd_input_transform mfr(F_4_3)
// CHECK: tensor.pad
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK: linalg.winograd_output_transform m(4) r(3)
+// CHECK: linalg.winograd_output_transform mfr(F_4_3)
// -----
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 16d06a7473272..ffa2faf491732 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -3,13 +3,13 @@
func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%2 = tensor.empty() : tensor<6x6x5x2xf32>
- %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %3 = linalg.winograd_filter_transform mfr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
%padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst : f32
} : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
%4 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
- %5 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%4 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+ %5 = linalg.winograd_input_transform mfr(F_4_3) ins(%padded : tensor<2x14x14x5xf32>) outs(%4 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
%collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
%collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
%6 = tensor.empty() : tensor<36x18x2xf32>
@@ -20,7 +20,7 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst : f32
} : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
- %9 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %9 = linalg.winograd_output_transform mfr(F_4_3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
%extracted_slice = tensor.extract_slice %9[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
return %extracted_slice : tensor<2x9x9x2xf32>
}
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index 0040d81a2d24e..14427e6c2c7c7 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -9,16 +9,16 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x4x4x2xf32>
// CHECK-NEXT: }
@@ -33,16 +33,16 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_2_5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_2_5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_2_5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x2x2x2xf32>
// CHECK-NEXT: }
@@ -57,16 +57,16 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x1x4x2xf32>
// CHECK-NEXT: }
@@ -81,16 +81,16 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x4x1x2xf32>
// CHECK-NEXT: }
@@ -105,16 +105,16 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x8x8x2xf32>
// CHECK-NEXT: }
@@ -129,13 +129,13 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
// CHECK-NEXT: ^bb0
// CHECK-NEXT: tensor.yield %[[CST]] : f32
// CHECK-NEXT: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %3 {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
@@ -146,7 +146,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
// CHECK-NEXT: ^bb0
// CHECK-NEXT: tensor.yield %[[CST]] : f32
// CHECK-NEXT: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
// CHECK-NEXT: }
@@ -162,16 +162,16 @@ func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
// CHECK-NEXT: return %[[S7]] : tensor<2x4x4x2xf32>
// CHECK-NEXT: }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 738648b8ccdcf..3160cad9b30bf 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -211,8 +212,8 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
static void applyWinogradConv2D(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
- populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
- populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
+ populateWinogradConv2DPatterns(patterns, WinogradConv2DFmr::F_4_3);
+ populateWinogradConv2DPatterns(patterns, WinogradConv2DFmr::F_2_5);
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
>From e65c956d77e3ca3d9db30e409c99a639c1bef228 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 18 Jun 2025 10:17:09 +0100
Subject: [PATCH 2/4] add comments for public APIs
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 69e09f6b32c2d..9e29611ae601f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -102,8 +102,15 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
namespace mlir {
namespace linalg {
+
+/// Converts the given `m` and `r` parameters to a WinogradConv2DFmr enumeration
+/// value.
WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
+
+/// Converts the given WinogradConv2DFmr enumeration value to a pair of
+/// m and r parameters.
std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
+
} // namespace linalg
} // namespace mlir
>From 34f29d0a8dc0d3a85fb29a9b08da0e0229d84b59 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 18 Jun 2025 22:52:30 +0100
Subject: [PATCH 3/4] address comments and fix typo in test cases
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 2 +-
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 7 +---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 +--
.../Linalg/Transforms/WinogradConv2D.cpp | 3 --
.../transform-tile-and-winograd-rewrite.mlir | 6 +--
.../Linalg/transform-tile-winograd.mlir | 36 ++++++++--------
.../Linalg/transform-winograd-conv2d.mlir | 24 +++++------
.../Linalg/winograd-conv2d-rewrite.mlir | 6 +--
mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 42 +++++++++----------
9 files changed, 62 insertions(+), 70 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 9e29611ae601f..4f5fea107f07b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -105,7 +105,7 @@ namespace linalg {
/// Converts the given `m` and `r` parameters to a WinogradConv2DFmr enumeration
/// value.
-WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
+std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r);
/// Converts the given WinogradConv2DFmr enumeration value to a pair of
/// m and r parameters.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 8c98c0b8b8683..1109db973f522 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -129,14 +129,11 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
/// For example, when m = 2 and r = 3, we know its input size is 4.
/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
/// 2x2 output result.
-def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr",
- "Winograd Conv2D F(m, r)",
- [
+def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr", "", [
I32EnumAttrCase<"F_2_3", 0>,
I32EnumAttrCase<"F_4_3", 1>,
I32EnumAttrCase<"F_2_5", 2>,
- I32EnumAttrCase<"Unknown", -1>,
- ]>{
+]>{
let cppNamespace = "mlir::linalg";
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ac592ad808311..b2639edb0d0f5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3631,14 +3631,14 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
namespace mlir {
namespace linalg {
-WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r) {
+std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
if (m == 2 && r == 3)
return WinogradConv2DFmr::F_2_3;
if (m == 4 && r == 3)
return WinogradConv2DFmr::F_4_3;
if (m == 2 && r == 5)
return WinogradConv2DFmr::F_2_5;
- return WinogradConv2DFmr::Unknown;
+ return std::nullopt;
}
std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
@@ -3649,8 +3649,6 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
return {4, 3};
case WinogradConv2DFmr::F_2_5:
return {2, 5};
- default:
- return {-1, -1};
}
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index a4f835350fb52..f0df2f01ef785 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -952,9 +952,6 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
return rewriter.notifyMatchFailure(
convOp, "only support filter (r x r), (r x 1) or (1 x r)");
- if (fmr == WinogradConv2DFmr::Unknown)
- return failure();
-
// All the criterias are satisfied. We can do Winograd Conv2D.
Location loc = convOp.getLoc();
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index fe5124eb251d9..445ded4bfcafb 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -350,16 +350,16 @@ module attributes {transform.with_named_sequence} {
func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<6x1x5x2xf32>
- %1 = linalg.winograd_filter_transform mfr(F_4_3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
%2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
- %3 = linalg.winograd_input_transform mfr(F_4_3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
%4 = tensor.empty() : tensor<6x4x2xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
%6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
%expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
- %7 = linalg.winograd_output_transform mfr(F_4_3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
+ %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
return %7 : tensor<2x4x2x2xf32>
}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
index 688264cc9a7aa..beb8d0b125738 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s
func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
- %0 = linalg.winograd_filter_transform mfr(F_4_3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
return %0 : tensor<6x6x5x2xf32>
}
@@ -25,13 +25,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
// CHECK: %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
-// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S3]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
// -----
func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
- %0 = linalg.winograd_filter_transform mfr(F_4_3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
return %0 : tensor<6x6x5x2xf32>
}
@@ -58,12 +58,12 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
// CHECK: %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
-// CHECK: %[[S4:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
+// CHECK: %[[S4:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S4]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x?x1xf32> into tensor<6x6x5x2xf32>
// -----
func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
- %0 = linalg.winograd_filter_transform mfr(F_4_3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %0 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
return %0 : tensor<6x1x5x2xf32>
}
@@ -87,13 +87,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
// CHECK: %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
-// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S3]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
// -----
func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
- %0 = linalg.winograd_input_transform mfr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
return %0 : tensor<6x6x2x2x2x5xf32>
}
@@ -123,13 +123,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]]()
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[S7:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[S7:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S7]] into %[[ARG5]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
// -----
func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
- %0 = linalg.winograd_input_transform mfr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
return %0 : tensor<6x6x2x2x2x5xf32>
}
@@ -167,13 +167,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]()
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[S9:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x2x2x2x5xf32>
// -----
func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
- %0 = linalg.winograd_input_transform mfr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
return %0 : tensor<6x6x2x2x2x5xf32>
}
@@ -213,13 +213,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S9:.*]] = affine.apply #[[$MAP2]]()
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
-// CHECK: %[[S10:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+// CHECK: %[[S10:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x?xf32> into tensor<6x6x2x2x2x5xf32>
// -----
func.func @tile_winograd_input(%arg0: tensor<2x1x10x5xf32>, %arg1: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
- %0 = linalg.winograd_input_transform mfr(F_4_3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
+ %0 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
return %0 : tensor<1x6x1x2x2x5xf32>
}
@@ -258,13 +258,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]]()
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
-// CHECK: %[[S9:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x1x1x1xf32> into tensor<1x6x1x2x2x5xf32>
// -----
func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x2x2xf32>, %arg1: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
- %0 = linalg.winograd_output_transform mfr(F_4_3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
return %0 : tensor<2x8x8x2xf32>
}
@@ -298,7 +298,7 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x3x5xf32>, %arg1: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
- %0 = linalg.winograd_output_transform mfr(F_4_3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
return %0 : tensor<3x8x8x5xf32>
}
@@ -346,7 +346,7 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @tile_winograd_output(%arg0 : tensor<6x1x2x1x3x5xf32>, %arg1: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
- %0 = linalg.winograd_output_transform mfr(F_4_3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
+ %0 = linalg.winograd_output_transform fmr(F_4_3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
return %0 : tensor<3x8x1x5xf32>
}
@@ -385,4 +385,4 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S7:.*]] = affine.apply #[[$MAP2]]()
// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]]()
// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
-// CHECK: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index f2b77dde0fe6a..e0ead54c956fc 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s
func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
@@ -8,16 +8,16 @@ func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
// CHECK-LABEL: func.func @conv2d
-// CHECK: linalg.winograd_filter_transform mfr(F_4_3)
-// CHECK: linalg.winograd_input_transform mfr(F_4_3)
+// CHECK: linalg.winograd_filter_transform fmr(F_4_3)
+// CHECK: linalg.winograd_input_transform fmr(F_4_3)
// CHECK: linalg.batch_matmul
-// CHECK: linalg.winograd_output_transform mfr(F_4_3)
+// CHECK: linalg.winograd_output_transform fmr(F_4_3)
// -----
@@ -29,19 +29,19 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK: linalg.winograd_filter_transform mfr(F_4_3)
+// CHECK: linalg.winograd_filter_transform fmr(F_4_3)
// CHECK: tensor.pad
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK: linalg.winograd_input_transform mfr(F_4_3)
+// CHECK: linalg.winograd_input_transform fmr(F_4_3)
// CHECK: tensor.pad
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK: linalg.winograd_output_transform mfr(F_4_3)
+// CHECK: linalg.winograd_output_transform fmr(F_4_3)
// -----
@@ -54,7 +54,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}}
- %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -70,7 +70,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{apply Winograd Conv2D failed}}
- %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -86,7 +86,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{apply Winograd Conv2D failed}}
- %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ %1 = transform.structured.winograd_conv2d %0 { fmr = 1: i32 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index ffa2faf491732..c7b0bd51308ba 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -3,13 +3,13 @@
func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%2 = tensor.empty() : tensor<6x6x5x2xf32>
- %3 = linalg.winograd_filter_transform mfr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %3 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
%padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst : f32
} : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
%4 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
- %5 = linalg.winograd_input_transform mfr(F_4_3) ins(%padded : tensor<2x14x14x5xf32>) outs(%4 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+ %5 = linalg.winograd_input_transform fmr(F_4_3) ins(%padded : tensor<2x14x14x5xf32>) outs(%4 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
%collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
%collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
%6 = tensor.empty() : tensor<36x18x2xf32>
@@ -20,7 +20,7 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst : f32
} : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
- %9 = linalg.winograd_output_transform mfr(F_4_3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %9 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
%extracted_slice = tensor.extract_slice %9[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
return %extracted_slice : tensor<2x9x9x2xf32>
}
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index 14427e6c2c7c7..e80fa6b4af944 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -9,16 +9,16 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x4x4x2xf32>
// CHECK-NEXT: }
@@ -33,16 +33,16 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_2_5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_2_5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_2_5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform fmr(F_2_5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_2_5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform fmr(F_2_5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x2x2x2xf32>
// CHECK-NEXT: }
@@ -57,16 +57,16 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x1x4x2xf32>
// CHECK-NEXT: }
@@ -81,16 +81,16 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x4x1x2xf32>
// CHECK-NEXT: }
@@ -105,16 +105,16 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
// CHECK-NEXT: return %[[S9]] : tensor<2x8x8x2xf32>
// CHECK-NEXT: }
@@ -129,13 +129,13 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
// CHECK-NEXT: ^bb0
// CHECK-NEXT: tensor.yield %[[CST]] : f32
// CHECK-NEXT: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %3 {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
@@ -146,7 +146,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
// CHECK-NEXT: ^bb0
// CHECK-NEXT: tensor.yield %[[CST]] : f32
// CHECK-NEXT: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
// CHECK-NEXT: }
@@ -162,16 +162,16 @@ func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform mfr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform fmr(F_4_3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform mfr(F_4_3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform fmr(F_4_3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform mfr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform fmr(F_4_3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
// CHECK-NEXT: return %[[S7]] : tensor<2x4x4x2xf32>
// CHECK-NEXT: }
>From b7631e82399138da9a5970ef4fb0df741d9805aa Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 25 Jun 2025 10:30:24 +0100
Subject: [PATCH 4/4] document how to add transform matrices
---
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index f0df2f01ef785..c61b23c63dc56 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -39,6 +39,15 @@ namespace {
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
+///
+/// To add more transformation matrices, we need to add the following
+/// items:
+/// 1. Add the constant transformation matrix to the corresponding
+/// G, GT, BT, B, AT, or A array.
+/// 2. Add the corresponding TransformMatrix to the GMatrices, GTMatrices,
+/// BTMatrices, BMatrices, ATMatrices, or AMatrices map.
+/// 3. Add an enum value F_m_r to the WinogradConv2DFmr enum.
+///
constexpr float G_2x2_3x3[] = {
-1, 0, 0,
1./2, -1./2, 1./2,
More information about the Mlir-commits
mailing list