[Mlir-commits] [mlir] [MLIR][Linalg] Fix winograd op lowering for types smaller than f32 (PR #158500)
Isaac Nudelman
llvmlistbot at llvm.org
Sun Sep 14 11:02:10 PDT 2025
https://github.com/nuudlman created https://github.com/llvm/llvm-project/pull/158500
The Winograd transform constant matrices are always stored as f32 data, but `create2DTransformMatrix` previously built the constant's tensor type from the operation's element type. When that type was narrower than f32 (e.g. f16), the f32 data did not match the element type, and creating the `DenseFPElementsAttr` hit an assertion failure.
This patch fixes the problem by always creating the Winograd constant matrices with element type f32. It also adds a regression test that lowers the Winograd ops with f16 inputs.
>From d4d2121b41d16a9658b8972dc47b79c15b50ca5f Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Fri, 12 Sep 2025 20:44:34 -0500
Subject: [PATCH 1/4] add a concept of a fix
---
.../Linalg/Transforms/WinogradConv2D.cpp | 47 ++++++++++---------
.../Linalg/winograd-conv2d-rewrite.mlir | 18 +++++++
2 files changed, 43 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b80b27fe5fcc5..288c8ada0c8eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -46,46 +46,46 @@ namespace {
/// BTMatrices, BMatrices, ATMatrices, or AMatrices map.
/// 3. Add a enum value F_m_r to WinogradConv2DFmr enum.
///
-constexpr float G_2x2_3x3[] = {
+constexpr double G_2x2_3x3[] = {
-1, 0, 0,
1./2, -1./2, 1./2,
1./2, 1./2, 1./2,
0, 0, 1
};
-constexpr float GT_2x2_3x3[] = {
+constexpr double GT_2x2_3x3[] = {
-1, 1./2, 1./2, 0,
0, -1./2, 1./2, 0,
0, 1./2, 1./2, 1
};
-constexpr float BT_2x2_3x3[] = {
+constexpr double BT_2x2_3x3[] = {
-1, 0, 1, 0,
0, -1, 1, 0,
0, 1, 1, 0,
0, -1, 0, 1
};
-constexpr float B_2x2_3x3[] = {
+constexpr double B_2x2_3x3[] = {
-1, 0, 0, 0,
0, -1, 1, -1,
1, 1, 1, 0,
0, 0, 0, 1
};
-constexpr float AT_2x2_3x3[] = {
+constexpr double AT_2x2_3x3[] = {
1, 1, 1, 0,
0, -1, 1, 1
};
-constexpr float A_2x2_3x3[] = {
+constexpr double A_2x2_3x3[] = {
1, 0,
1, -1,
1, 1,
0, 1
};
-constexpr float G_4x4_3x3[] = {
+constexpr double G_4x4_3x3[] = {
1, 0, 0,
-1./3, 1./3, -1./3,
-1./3, -1./3, -1./3,
@@ -94,13 +94,13 @@ constexpr float G_4x4_3x3[] = {
0, 0, 1
};
-constexpr float GT_4x4_3x3[] = {
+constexpr double GT_4x4_3x3[] = {
1, -1./3, -1./3, 1./12, 1./12, 0,
0, 1./3, -1./3, -1./6, 1./6, 0,
0, -1./3, -1./3, 1./3, 1./3, 1
};
-constexpr float BT_4x4_3x3[] = {
+constexpr double BT_4x4_3x3[] = {
1./4, 0, -5./16, 0, 1./16, 0,
0, 1./4, -1./4, -1./16, 1./16, 0,
0, -1./4, -1./4, 1./16, 1./16, 0,
@@ -109,7 +109,7 @@ constexpr float BT_4x4_3x3[] = {
0, 1./4, 0, -5./16, 0, 1./16
};
-constexpr float B_4x4_3x3[] = {
+constexpr double B_4x4_3x3[] = {
1./4, 0, 0, 0, 0, 0,
0, 1./4, -1./4, 1./4, -1./4, 1./4,
-5./16, -1./4, -1./4, -1./8, -1./8, 0,
@@ -118,14 +118,14 @@ constexpr float B_4x4_3x3[] = {
0, 0, 0, 0, 0, 1./16
};
-constexpr float AT_4x4_3x3[] = {
+constexpr double AT_4x4_3x3[] = {
1./8, 1./4, 1./4, 1./8, 1./8, 0,
0, -1./4, 1./4, -1./4, 1./4, 0,
0, 1./4, 1./4, 1./2, 1./2, 0,
0, -1./4, 1./4, -1, 1, 1./2
};
-constexpr float A_4x4_3x3[] = {
+constexpr double A_4x4_3x3[] = {
1./8, 0, 0, 0,
1./4, -1./4, 1./4, -1./4,
1./4, 1./4, 1./4, 1./4,
@@ -134,7 +134,7 @@ constexpr float A_4x4_3x3[] = {
0, 0, 0, 1./2
};
-constexpr float G_2x2_5x5[] = {
+constexpr double G_2x2_5x5[] = {
1, 0, 0, 0, 0,
1./6, -1./6, 1./6, -1./6, 1./6,
-1./6, -1./6, -1./6, -1./6, -1./6,
@@ -143,7 +143,7 @@ constexpr float G_2x2_5x5[] = {
0, 0, 0, 0, 1
};
-constexpr float GT_2x2_5x5[] = {
+constexpr double GT_2x2_5x5[] = {
1, 1./6, -1./6, -4./15, 1./60, 0,
0, -1./6, -1./6, 2./15, 1./30, 0,
0, 1./6, -1./6, -1./15, 1./15, 0,
@@ -151,7 +151,7 @@ constexpr float GT_2x2_5x5[] = {
0, 1./6, -1./6, -1./60, 4./15, 1
};
-constexpr float BT_2x2_5x5[] = {
+constexpr double BT_2x2_5x5[] = {
1./8, 3./16, -1./4, -3./16, 1./8, 0,
0, 1./8, 1./16, -5./16, 1./8, 0,
0, -1./8, -5./16, -1./16, 1./8, 0,
@@ -160,7 +160,7 @@ constexpr float BT_2x2_5x5[] = {
0, 1./8, 3./16, -1./4, -3./16, 1./8
};
-constexpr float B_2x2_5x5[] = {
+constexpr double B_2x2_5x5[] = {
1./8, 0, 0, 0, 0, 0,
3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
-1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
@@ -169,12 +169,12 @@ constexpr float B_2x2_5x5[] = {
0, 0, 0, 0, 0, 1./8
};
-constexpr float AT_2x2_5x5[] = {
+constexpr double AT_2x2_5x5[] = {
1./2, 1, 1, 2, 1, 0,
0, -1, 1, -1, 2, 1./2
};
-constexpr float A_2x2_5x5[] = {
+constexpr double A_2x2_5x5[] = {
1./2, 0,
1, -1,
1, 1,
@@ -186,11 +186,12 @@ constexpr float A_2x2_5x5[] = {
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
- TransformMatrix(const float *table, int64_t rows, int64_t cols,
+ TransformMatrix(ArrayRef<double> table, int64_t rows, int64_t cols,
int64_t scalarFactor = 1)
- : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+ : table(llvm::map_to_vector(table, [](double val) { return APFloat(val); })), rows(rows), cols(cols), scalarFactor(scalarFactor) {
+ }
- const float *table;
+ SmallVector<APFloat> table;
int64_t rows;
int64_t cols;
int64_t scalarFactor;
@@ -199,7 +200,9 @@ struct TransformMatrix {
/// Utility function to convert constant array to arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
TransformMatrix transform, Type type) {
- ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+ assert(type.isFloat());
+ assert(transform.table.size() == (transform.rows * transform.cols));
+ ArrayRef<APFloat> constVec(transform.table.data(), transform.rows * transform.cols);
return arith::ConstantOp::create(
builder, loc,
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index c7b0bd51308ba..0c7d4e1d23f34 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -127,3 +127,21 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<6x6x5x2xf16>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf16>) outs(%0 : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> // no-crash
+ %2 = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf16>) outs(%2 : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16> // no-crash
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+ %4 = tensor.empty() : tensor<36x2x2xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+ %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%5 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+ %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+ %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+ return %7 : tensor<2x4x4x2xf32>
+}
>From 4ca18ad309a18282460bbc1b472330012fec1cd7 Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Sun, 14 Sep 2025 11:04:22 -0500
Subject: [PATCH 2/4] Always promote winograd lowering to f32
---
.../Linalg/Transforms/WinogradConv2D.cpp | 69 +++++++++----------
1 file changed, 34 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 288c8ada0c8eb..2e884f6f79ef6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -46,46 +46,46 @@ namespace {
/// BTMatrices, BMatrices, ATMatrices, or AMatrices map.
/// 3. Add a enum value F_m_r to WinogradConv2DFmr enum.
///
-constexpr double G_2x2_3x3[] = {
+constexpr float G_2x2_3x3[] = {
-1, 0, 0,
1./2, -1./2, 1./2,
1./2, 1./2, 1./2,
0, 0, 1
};
-constexpr double GT_2x2_3x3[] = {
+constexpr float GT_2x2_3x3[] = {
-1, 1./2, 1./2, 0,
0, -1./2, 1./2, 0,
0, 1./2, 1./2, 1
};
-constexpr double BT_2x2_3x3[] = {
+constexpr float BT_2x2_3x3[] = {
-1, 0, 1, 0,
0, -1, 1, 0,
0, 1, 1, 0,
0, -1, 0, 1
};
-constexpr double B_2x2_3x3[] = {
+constexpr float B_2x2_3x3[] = {
-1, 0, 0, 0,
0, -1, 1, -1,
1, 1, 1, 0,
0, 0, 0, 1
};
-constexpr double AT_2x2_3x3[] = {
+constexpr float AT_2x2_3x3[] = {
1, 1, 1, 0,
0, -1, 1, 1
};
-constexpr double A_2x2_3x3[] = {
+constexpr float A_2x2_3x3[] = {
1, 0,
1, -1,
1, 1,
0, 1
};
-constexpr double G_4x4_3x3[] = {
+constexpr float G_4x4_3x3[] = {
1, 0, 0,
-1./3, 1./3, -1./3,
-1./3, -1./3, -1./3,
@@ -94,13 +94,13 @@ constexpr double G_4x4_3x3[] = {
0, 0, 1
};
-constexpr double GT_4x4_3x3[] = {
+constexpr float GT_4x4_3x3[] = {
1, -1./3, -1./3, 1./12, 1./12, 0,
0, 1./3, -1./3, -1./6, 1./6, 0,
0, -1./3, -1./3, 1./3, 1./3, 1
};
-constexpr double BT_4x4_3x3[] = {
+constexpr float BT_4x4_3x3[] = {
1./4, 0, -5./16, 0, 1./16, 0,
0, 1./4, -1./4, -1./16, 1./16, 0,
0, -1./4, -1./4, 1./16, 1./16, 0,
@@ -109,7 +109,7 @@ constexpr double BT_4x4_3x3[] = {
0, 1./4, 0, -5./16, 0, 1./16
};
-constexpr double B_4x4_3x3[] = {
+constexpr float B_4x4_3x3[] = {
1./4, 0, 0, 0, 0, 0,
0, 1./4, -1./4, 1./4, -1./4, 1./4,
-5./16, -1./4, -1./4, -1./8, -1./8, 0,
@@ -118,14 +118,14 @@ constexpr double B_4x4_3x3[] = {
0, 0, 0, 0, 0, 1./16
};
-constexpr double AT_4x4_3x3[] = {
+constexpr float AT_4x4_3x3[] = {
1./8, 1./4, 1./4, 1./8, 1./8, 0,
0, -1./4, 1./4, -1./4, 1./4, 0,
0, 1./4, 1./4, 1./2, 1./2, 0,
0, -1./4, 1./4, -1, 1, 1./2
};
-constexpr double A_4x4_3x3[] = {
+constexpr float A_4x4_3x3[] = {
1./8, 0, 0, 0,
1./4, -1./4, 1./4, -1./4,
1./4, 1./4, 1./4, 1./4,
@@ -134,7 +134,7 @@ constexpr double A_4x4_3x3[] = {
0, 0, 0, 1./2
};
-constexpr double G_2x2_5x5[] = {
+constexpr float G_2x2_5x5[] = {
1, 0, 0, 0, 0,
1./6, -1./6, 1./6, -1./6, 1./6,
-1./6, -1./6, -1./6, -1./6, -1./6,
@@ -143,7 +143,7 @@ constexpr double G_2x2_5x5[] = {
0, 0, 0, 0, 1
};
-constexpr double GT_2x2_5x5[] = {
+constexpr float GT_2x2_5x5[] = {
1, 1./6, -1./6, -4./15, 1./60, 0,
0, -1./6, -1./6, 2./15, 1./30, 0,
0, 1./6, -1./6, -1./15, 1./15, 0,
@@ -151,7 +151,7 @@ constexpr double GT_2x2_5x5[] = {
0, 1./6, -1./6, -1./60, 4./15, 1
};
-constexpr double BT_2x2_5x5[] = {
+constexpr float BT_2x2_5x5[] = {
1./8, 3./16, -1./4, -3./16, 1./8, 0,
0, 1./8, 1./16, -5./16, 1./8, 0,
0, -1./8, -5./16, -1./16, 1./8, 0,
@@ -160,7 +160,7 @@ constexpr double BT_2x2_5x5[] = {
0, 1./8, 3./16, -1./4, -3./16, 1./8
};
-constexpr double B_2x2_5x5[] = {
+constexpr float B_2x2_5x5[] = {
1./8, 0, 0, 0, 0, 0,
3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
-1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
@@ -169,12 +169,12 @@ constexpr double B_2x2_5x5[] = {
0, 0, 0, 0, 0, 1./8
};
-constexpr double AT_2x2_5x5[] = {
+constexpr float AT_2x2_5x5[] = {
1./2, 1, 1, 2, 1, 0,
0, -1, 1, -1, 2, 1./2
};
-constexpr double A_2x2_5x5[] = {
+constexpr float A_2x2_5x5[] = {
1./2, 0,
1, -1,
1, 1,
@@ -186,12 +186,12 @@ constexpr double A_2x2_5x5[] = {
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
- TransformMatrix(ArrayRef<double> table, int64_t rows, int64_t cols,
+ TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
int64_t scalarFactor = 1)
- : table(llvm::map_to_vector(table, [](double val) { return APFloat(val); })), rows(rows), cols(cols), scalarFactor(scalarFactor) {
+ : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {
}
- SmallVector<APFloat> table;
+ ArrayRef<float> table;
int64_t rows;
int64_t cols;
int64_t scalarFactor;
@@ -199,17 +199,14 @@ struct TransformMatrix {
/// Utility function to convert constant array to arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
- TransformMatrix transform, Type type) {
- assert(type.isFloat());
- assert(transform.table.size() == (transform.rows * transform.cols));
- ArrayRef<APFloat> constVec(transform.table.data(), transform.rows * transform.cols);
-
+ TransformMatrix transform) {
+ assert(transform.table.size() == static_cast<size_t>(transform.rows * transform.cols));
+ ArrayRef<float> constVec(transform.table.data(), transform.rows * transform.cols);
+ SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
return arith::ConstantOp::create(
builder, loc,
DenseFPElementsAttr::get(
- RankedTensorType::get(
- SmallVector<int64_t>{transform.rows, transform.cols}, type),
- constVec));
+ RankedTensorType::get(shape, builder.getF32Type()), constVec));
}
/// Extract height x width data from 4D tensors.
@@ -407,7 +404,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
+ Value G = create2DTransformMatrix(builder, loc, GMatrix);
// Multiply G x g.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{G, extractFilter},
@@ -430,7 +427,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
+ Value GT = create2DTransformMatrix(builder, loc, GTMatrix);
// Multiply u = (G x g) x GT.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, GT},
@@ -500,6 +497,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
auto inputType = cast<ShapedType>(input.getType());
Type elementType = inputType.getElementType();
+ // assert(elementType.isF32() && "NYI: support non-f32");
auto inputShape = inputType.getShape(); // N, H, W, C
int64_t inputN = inputShape[0];
int64_t inputC = inputShape[3];
@@ -555,7 +553,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value BT =
- create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+ create2DTransformMatrix(builder, loc, BTMatrix);
// Multiply BT x d.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{BT, matmulRetValue},
@@ -578,7 +576,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value B =
- create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+ create2DTransformMatrix(builder, loc, BMatrix);
// Multiply v = (BT x d) x B.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, B},
@@ -723,6 +721,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
auto valueType = cast<ShapedType>(value.getType());
Type elementType = valueType.getElementType();
+ // assert(elementType.isF32() && "NYI: support non-f32");
auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
int64_t valueH = valueShape[0];
int64_t valueW = valueShape[1];
@@ -786,7 +785,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
}
- Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
+ Value AT = create2DTransformMatrix(builder, loc, ATMatrix);
// Multiply AT x m.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{AT, matmulRetValue},
@@ -805,7 +804,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
}
- Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
+ Value A = create2DTransformMatrix(builder, loc, AMatrix);
// Multiply y = (AT x m) x A.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, A},
>From f94946b4a99d56b30aea8d7b0bf663d881c744a8 Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Sun, 14 Sep 2025 12:25:12 -0500
Subject: [PATCH 3/4] Fix test
---
.../Linalg/winograd-conv2d-rewrite.mlir | 98 +++++++++++++++++++
1 file changed, 98 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 0c7d4e1d23f34..4bcb9b0c2c465 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -145,3 +145,101 @@ func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3
%7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
return %7 : tensor<2x4x4x2xf32>
}
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_type_promotion(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x6x6x5xf16>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<2x3x3x5xf16>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<1xf32>,
+// CHECK-SAME: %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT: %[[VAL_20:.*]] = tensor.extract_slice %[[ARG1]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_11]], %[[VAL_18]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf16> to tensor<3x3xf16>
+// CHECK-NEXT: %[[VAL_21:.*]] = tensor.empty() : tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_22:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_21]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_23:.*]] = linalg.matmul ins(%[[VAL_6]], %[[VAL_20]] : tensor<6x3xf32>, tensor<3x3xf16>) outs(%[[VAL_22]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_24:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_25:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_24]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_26:.*]] = linalg.matmul ins(%[[VAL_23]], %[[VAL_5]] : tensor<6x3xf16>, tensor<3x6xf32>) outs(%[[VAL_25]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_27:.*]] = tensor.insert_slice %[[VAL_26]] into %[[VAL_19]]{{\[}}%[[VAL_11]], %[[VAL_11]], %[[VAL_18]], %[[VAL_15]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x5x2xf16>
+// CHECK-NEXT: scf.yield %[[VAL_27]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_17]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[VAL_28:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_32:.*]] = scf.for %[[VAL_33:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_34:.*]] = %[[VAL_31]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_35:.*]] = scf.for %[[VAL_36:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_37:.*]] = %[[VAL_34]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_38:.*]] = scf.for %[[VAL_39:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_40:.*]] = %[[VAL_37]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_41:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_30]])
+// CHECK-NEXT: %[[VAL_42:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_33]])
+// CHECK-NEXT: %[[VAL_43:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[VAL_36]], %[[VAL_41]], %[[VAL_42]], %[[VAL_39]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf16> to tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_44:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_45:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_44]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_46:.*]] = linalg.matmul ins(%[[VAL_4]], %[[VAL_43]] : tensor<6x6xf32>, tensor<6x6xf16>) outs(%[[VAL_45]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_47:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_48:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_47]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_49:.*]] = linalg.matmul ins(%[[VAL_46]], %[[VAL_3]] : tensor<6x6xf16>, tensor<6x6xf32>) outs(%[[VAL_48]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_50:.*]] = tensor.insert_slice %[[VAL_49]] into %[[VAL_40]][0, 0, %[[VAL_30]], %[[VAL_33]], %[[VAL_36]], %[[VAL_39]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: scf.yield %[[VAL_50]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_38]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_35]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_32]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[VAL_51:.*]] = tensor.collapse_shape %[[VAL_14]] {{\[\[}}0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT: %[[VAL_52:.*]] = tensor.collapse_shape %[[VAL_29]] {{\[\[}}0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT: %[[VAL_53:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_54:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_53]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_55:.*]] = linalg.batch_matmul ins(%[[VAL_52]], %[[VAL_51]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[VAL_54]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_56:.*]] = tensor.expand_shape %[[VAL_55]] {{\[\[}}0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_59:.*]] = %[[ARG3]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_59]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_65:.*]] = %[[VAL_62]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_66:.*]] = scf.for %[[VAL_67:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_68:.*]] = %[[VAL_65]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_69:.*]] = tensor.extract_slice %[[VAL_56]][0, 0, %[[VAL_58]], %[[VAL_61]], %[[VAL_64]], %[[VAL_67]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT: %[[VAL_70:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_58]])
+// CHECK-NEXT: %[[VAL_71:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_61]])
+// CHECK-NEXT: %[[VAL_72:.*]] = tensor.extract_slice %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x4x4x2xf32> to tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_73:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_74:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_73]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_75:.*]] = linalg.matmul ins(%[[VAL_2]], %[[VAL_69]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[VAL_74]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_76:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_76]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_78:.*]] = linalg.matmul ins(%[[VAL_75]], %[[VAL_1]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[VAL_77]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_79:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_78]] : f32, tensor<4x4xf32>) outs(%[[VAL_72]] : tensor<4x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[VAL_80:.*]]: f32, %[[VAL_81:.*]]: f32, %[[VAL_82:.*]]: f32):
+// CHECK-NEXT: %[[VAL_83:.*]] = arith.mulf %[[VAL_80]], %[[VAL_81]] : f32
+// CHECK-NEXT: %[[VAL_84:.*]] = arith.addf %[[VAL_83]], %[[VAL_82]] : f32
+// CHECK-NEXT: linalg.yield %[[VAL_84]] : f32
+// CHECK-NEXT: } -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_85:.*]] = tensor.insert_slice %[[VAL_79]] into %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT: scf.yield %[[VAL_85]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_66]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_63]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_60]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[VAL_57]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
>From eec3881e3220661e2c28e0b05d1901b1b85b4ded Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Sun, 14 Sep 2025 12:56:16 -0500
Subject: [PATCH 4/4] Remove debug asserts
---
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 2e884f6f79ef6..b875b24c8fda0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -497,7 +497,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
auto inputType = cast<ShapedType>(input.getType());
Type elementType = inputType.getElementType();
- // assert(elementType.isF32() && "NYI: support non-f32");
auto inputShape = inputType.getShape(); // N, H, W, C
int64_t inputN = inputShape[0];
int64_t inputC = inputShape[3];
@@ -721,7 +720,6 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
auto valueType = cast<ShapedType>(value.getType());
Type elementType = valueType.getElementType();
- // assert(elementType.isF32() && "NYI: support non-f32");
auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
int64_t valueH = valueShape[0];
int64_t valueW = valueShape[1];
More information about the Mlir-commits mailing list